28 for (
auto it : llvm::zip(headBlock->
getArguments(), operands))
33 scf::YieldOp::create(rewriter, yield.getLoc(), yield.getInputs());
49 auto condition = tensor::ExtractOp::create(rewriter, yield.getLoc(),
51 scf::ConditionOp::create(rewriter, yield.getLoc(), condition,
55 scf::YieldOp::create(rewriter, yield.getLoc(), yield.getInputs());
64 using OpRewritePattern<tosa::IfOp>::OpRewritePattern;
66 LogicalResult matchAndRewrite(tosa::IfOp op,
67 PatternRewriter &rewriter)
const final {
69 tensor::ExtractOp::create(rewriter, op.getLoc(), op.getCondition());
70 auto newIf = scf::IfOp::create(rewriter, op.getLoc(), op.getResultTypes(),
73 inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputList(),
75 inlineIfCase(op.getElseGraph(), newIf.getElseRegion(), op.getInputList(),
78 rewriter.replaceOp(op, newIf.getResults());
84 static Value createTensorDim(OpBuilder &builder, Location loc, Value tensor,
86 return builder.
createOrFold<tensor::DimOp>(loc, tensor, dim);
89 static Value createIndexConst(OpBuilder &builder, Location loc,
95 using OpRewritePattern<tosa::ScatterOp>::OpRewritePattern;
97 LogicalResult matchAndRewrite(tosa::ScatterOp scatter,
98 PatternRewriter &rewriter)
const final {
99 auto valuesIn = scatter.getValuesIn();
100 auto indices = scatter.getIndices();
101 auto input = scatter.getInput();
102 auto loc = scatter.getLoc();
105 auto dimN = createTensorDim(rewriter, loc, input, 0);
106 auto dimW = createTensorDim(rewriter, loc, input, 1);
107 auto dimC = createTensorDim(rewriter, loc, input, 2);
109 auto zero = createIndexConst(rewriter, loc, 0);
110 auto one = createIndexConst(rewriter, loc, 1);
113 auto lbs = llvm::SmallVector<Value>(2, zero);
114 auto steps = llvm::SmallVector<Value>(2, one);
115 auto ubs = llvm::SmallVector<Value>{{dimN, dimW}};
117 auto buildBody = [&](OpBuilder &builder, Location loc,
ValueRange ivs,
122 auto index = tensor::ExtractOp::create(builder, loc,
indices, ivs);
123 auto castIndex = arith::IndexCastOp::create(
127 auto inputOffset = llvm::to_vector(ivs);
128 inputOffset.push_back(zero);
130 llvm::SmallVector<Value> sizes = {one, one, dimC};
131 llvm::SmallVector<Value> strides = {one, one, one};
133 auto slice = tensor::ExtractSliceOp::create(builder, loc, input,
134 inputOffset, sizes, strides);
137 llvm::SmallVector<Value> outputOffset = {n, castIndex, zero};
138 auto updated = tensor::InsertSliceOp::create(
139 builder, loc, slice, args[0], outputOffset, sizes, strides);
146 rewriter.replaceOp(scatter, loops.results);
154 using OpRewritePattern<tosa::WhileOp>::OpRewritePattern;
156 LogicalResult matchAndRewrite(tosa::WhileOp op,
157 PatternRewriter &rewriter)
const final {
158 auto newWhile = scf::WhileOp::create(
159 rewriter, op.getLoc(), op.getResultTypes(), op.getInputList());
160 rewriter.createBlock(&newWhile.getBefore());
161 rewriter.createBlock(&newWhile.getAfter());
163 inlineWhileCase(op.getCondGraph(), newWhile.getBefore(), rewriter,
true);
164 inlineWhileCase(op.getBodyGraph(), newWhile.getAfter(), rewriter,
false);
166 rewriter.replaceOp(op, newWhile.getResults());
176 patterns->add<IfOpConverter, ScatterOpConverter, WhileOpConverter>(
static void inlineIfCase(Region &srcRegion, Region &dstRegion, OperandRange operands, PatternRewriter &rewriter)
static void inlineWhileCase(Region &srcRegion, Region &dstRegion, PatternRewriter &rewriter, bool isCond)
Block represents an ordered list of Operations.
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
BlockArgListType getArguments()
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to use the other value instead.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
This class implements the operand iterators for the Operation class.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
void populateTosaToSCFConversionPatterns(RewritePatternSet *patterns)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.