28 for (
auto it : llvm::zip(headBlock->
getArguments(), operands))
33 scf::YieldOp::create(rewriter, yield.getLoc(), yield.getInputs());
49 auto condition = tensor::ExtractOp::create(rewriter, yield.getLoc(),
51 scf::ConditionOp::create(rewriter, yield.getLoc(), condition,
55 scf::YieldOp::create(rewriter, yield.getLoc(), yield.getInputs());
66 LogicalResult matchAndRewrite(tosa::IfOp op,
69 tensor::ExtractOp::create(rewriter, op.getLoc(), op.getCondition());
70 auto newIf = scf::IfOp::create(rewriter, op.getLoc(), op.getResultTypes(),
73 inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputList(),
75 inlineIfCase(op.getElseGraph(), newIf.getElseRegion(), op.getInputList(),
78 rewriter.replaceOp(op, newIf.getResults());
86 return builder.
createOrFold<tensor::DimOp>(loc, tensor, dim);
97 LogicalResult matchAndRewrite(tosa::ScatterOp scatter,
99 auto valuesIn = scatter.getValuesIn();
100 auto indices = scatter.getIndices();
101 auto input = scatter.getInput();
102 auto loc = scatter.getLoc();
105 auto dimN = createTensorDim(rewriter, loc, input, 0);
106 auto dimW = createTensorDim(rewriter, loc, input, 1);
107 auto dimC = createTensorDim(rewriter, loc, input, 2);
109 auto zero = createIndexConst(rewriter, loc, 0);
110 auto one = createIndexConst(rewriter, loc, 1);
122 auto index = tensor::ExtractOp::create(builder, loc, indices, ivs);
123 auto castIndex = arith::IndexCastOp::create(
127 auto inputOffset = llvm::to_vector(ivs);
128 inputOffset.push_back(zero);
133 auto slice = tensor::ExtractSliceOp::create(builder, loc, input,
134 inputOffset, sizes, strides);
138 auto updated = tensor::InsertSliceOp::create(
139 builder, loc, slice, args[0], outputOffset, sizes, strides);
146 rewriter.replaceOp(scatter, loops.results);
156 LogicalResult matchAndRewrite(tosa::WhileOp op,
158 auto newWhile = scf::WhileOp::create(
159 rewriter, op.getLoc(), op.getResultTypes(), op.getInputList());
160 rewriter.createBlock(&newWhile.getBefore());
161 rewriter.createBlock(&newWhile.getAfter());
163 inlineWhileCase(op.getCondGraph(), newWhile.getBefore(), rewriter,
true);
164 inlineWhileCase(op.getBodyGraph(), newWhile.getAfter(), rewriter,
false);
166 rewriter.replaceOp(op, newWhile.getResults());
176 patterns->add<IfOpConverter, ScatterOpConverter, WhileOpConverter>(
static void inlineIfCase(Region &srcRegion, Region &dstRegion, OperandRange operands, PatternRewriter &rewriter)
static void inlineWhileCase(Region &srcRegion, Region &dstRegion, PatternRewriter &rewriter, bool isCond)
Block represents an ordered list of Operations.
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class implements the operand iterators for the Operation class.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
void populateTosaToSCFConversionPatterns(RewritePatternSet *patterns)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...