30 for (
auto it : llvm::zip(headBlock->
getArguments(), operands))
35 rewriter.
create<scf::YieldOp>(yield.getLoc(), yield.getInputs());
52 rewriter.
create<tensor::ExtractOp>(yield.getLoc(), yield.getOperand(0));
53 rewriter.
create<scf::ConditionOp>(yield.getLoc(), condition,
57 rewriter.
create<scf::YieldOp>(yield.getLoc(), yield.getInputs());
68 LogicalResult matchAndRewrite(tosa::IfOp op,
71 rewriter.
create<tensor::ExtractOp>(op.getLoc(), op.getCond());
72 auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
75 inlineIfCase(op.getThenBranch(), newIf.getThenRegion(), op.getInputs(),
77 inlineIfCase(op.getElseBranch(), newIf.getElseRegion(), op.getInputs(),
80 rewriter.replaceOp(op, newIf.getResults());
88 return builder.
createOrFold<tensor::DimOp>(loc, tensor, dim);
93 return builder.
create<arith::ConstantIndexOp>(loc, value);
99 LogicalResult matchAndRewrite(tosa::ScatterOp scatter,
101 auto valuesIn = scatter.getValuesIn();
102 auto indices = scatter.getIndices();
103 auto input = scatter.getInput();
104 auto loc = scatter.getLoc();
107 auto dimN = createTensorDim(rewriter, loc, input, 0);
108 auto dimW = createTensorDim(rewriter, loc, input, 1);
109 auto dimC = createTensorDim(rewriter, loc, input, 2);
111 auto zero = createIndexConst(rewriter, loc, 0);
112 auto one = createIndexConst(rewriter, loc, 1);
124 auto index = builder.
create<tensor::ExtractOp>(loc, indices, ivs);
125 auto castIndex = builder.
create<arith::IndexCastOp>(
129 auto inputOffset = llvm::to_vector(ivs);
130 inputOffset.push_back(zero);
135 auto slice = builder.
create<tensor::ExtractSliceOp>(
136 loc, input, inputOffset, sizes, strides);
140 auto updated = builder.
create<tensor::InsertSliceOp>(
141 loc, slice, args[0], outputOffset, sizes, strides);
148 rewriter.replaceOp(scatter, loops.results);
158 LogicalResult matchAndRewrite(tosa::WhileOp op,
160 auto newWhile = rewriter.
create<scf::WhileOp>(
161 op.getLoc(), op.getResultTypes(), op.getInputs());
162 rewriter.createBlock(&newWhile.getBefore());
163 rewriter.createBlock(&newWhile.getAfter());
168 rewriter.replaceOp(op, newWhile.getResults());
178 patterns->
add<IfOpConverter, ScatterOpConverter, WhileOpConverter>(
static void inlineIfCase(Region &srcRegion, Region &dstRegion, OperandRange operands, PatternRewriter &rewriter)
static void inlineWhileCase(Region &srcRegion, Region &dstRegion, PatternRewriter &rewriter, bool isCond)
Block represents an ordered list of Operations.
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void cloneRegionBefore(Region ®ion, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class implements the operand iterators for the Operation class.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static std::unique_ptr< T > create(Args &&...args)
This method provides a convenient interface for creating and initializing derived rewrite patterns of...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
void populateTosaToSCFConversionPatterns(RewritePatternSet *patterns)
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...