MLIR 22.0.0git
TosaToSCF.cpp
Go to the documentation of this file.
1//===- TosaToSCF.cpp - Lowering Tosa to SCF Dialect -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// These rewriters lower from the Tosa to the SCF dialect.
10//
11//===----------------------------------------------------------------------===//
12
18
19using namespace mlir;
20using namespace tosa;
21
22static void inlineIfCase(Region &srcRegion, Region &dstRegion,
23 OperandRange operands, PatternRewriter &rewriter) {
24 rewriter.cloneRegionBefore(srcRegion, &dstRegion.front());
25 rewriter.eraseBlock(&dstRegion.back());
26
27 Block *headBlock = &dstRegion.front();
28 for (auto it : llvm::zip(headBlock->getArguments(), operands))
29 std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
30
31 auto yield = cast<YieldOp>(headBlock->getTerminator());
32 rewriter.setInsertionPoint(yield);
33 scf::YieldOp::create(rewriter, yield.getLoc(), yield.getInputs());
34 rewriter.eraseOp(yield);
35
36 headBlock->eraseArguments(0, headBlock->getNumArguments());
37}
38
39static void inlineWhileCase(Region &srcRegion, Region &dstRegion,
40 PatternRewriter &rewriter, bool isCond) {
41 rewriter.cloneRegionBefore(srcRegion, &dstRegion.back());
42 rewriter.eraseBlock(&dstRegion.back());
43
44 Block *headBlock = &dstRegion.front();
45
46 auto yield = cast<YieldOp>(headBlock->getTerminator());
47 rewriter.setInsertionPoint(yield);
48 if (isCond) {
49 auto condition = tensor::ExtractOp::create(rewriter, yield.getLoc(),
50 yield.getOperand(0));
51 scf::ConditionOp::create(rewriter, yield.getLoc(), condition,
52 headBlock->getArguments());
53 } else {
54 rewriter.setInsertionPoint(yield);
55 scf::YieldOp::create(rewriter, yield.getLoc(), yield.getInputs());
56 }
57 rewriter.eraseOp(yield);
58}
59
60namespace {
61
62class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
63public:
64 using OpRewritePattern<tosa::IfOp>::OpRewritePattern;
65
66 LogicalResult matchAndRewrite(tosa::IfOp op,
67 PatternRewriter &rewriter) const final {
68 auto condition =
69 tensor::ExtractOp::create(rewriter, op.getLoc(), op.getCondition());
70 auto newIf = scf::IfOp::create(rewriter, op.getLoc(), op.getResultTypes(),
71 condition, true);
72
73 inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputList(),
74 rewriter);
75 inlineIfCase(op.getElseGraph(), newIf.getElseRegion(), op.getInputList(),
76 rewriter);
77
78 rewriter.replaceOp(op, newIf.getResults());
79 return success();
80 }
81};
82
83class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
84 static Value createTensorDim(OpBuilder &builder, Location loc, Value tensor,
85 int64_t dim) {
86 return builder.createOrFold<tensor::DimOp>(loc, tensor, dim);
87 }
88
89 static Value createIndexConst(OpBuilder &builder, Location loc,
90 int64_t value) {
91 return arith::ConstantIndexOp::create(builder, loc, value);
92 }
93
94public:
95 using OpRewritePattern<tosa::ScatterOp>::OpRewritePattern;
96
97 LogicalResult matchAndRewrite(tosa::ScatterOp scatter,
98 PatternRewriter &rewriter) const final {
99 auto valuesIn = scatter.getValuesIn();
100 auto indices = scatter.getIndices();
101 auto input = scatter.getInput();
102 auto loc = scatter.getLoc();
103
104 // N, W, C are chosen to match the TOSA spec
105 auto dimN = createTensorDim(rewriter, loc, input, 0);
106 auto dimW = createTensorDim(rewriter, loc, input, 1);
107 auto dimC = createTensorDim(rewriter, loc, input, 2);
108
109 auto zero = createIndexConst(rewriter, loc, 0);
110 auto one = createIndexConst(rewriter, loc, 1);
111
112 // Loop bounds
113 auto lbs = llvm::SmallVector<Value>(2, zero);
114 auto steps = llvm::SmallVector<Value>(2, one);
115 auto ubs = llvm::SmallVector<Value>{{dimN, dimW}};
116
117 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
119 auto n = ivs[0];
120
121 // Read the index and cast it to index type
122 auto index = tensor::ExtractOp::create(builder, loc, indices, ivs);
123 auto castIndex = arith::IndexCastOp::create(
124 builder, loc, builder.getIndexType(), index);
125
126 // Offset, sizes, and strides for the input tensor
127 auto inputOffset = llvm::to_vector(ivs);
128 inputOffset.push_back(zero);
129
130 llvm::SmallVector<Value> sizes = {one, one, dimC};
131 llvm::SmallVector<Value> strides = {one, one, one};
132
133 auto slice = tensor::ExtractSliceOp::create(builder, loc, input,
134 inputOffset, sizes, strides);
135
136 // Insert the slice into the output accumulator tensor.
137 llvm::SmallVector<Value> outputOffset = {n, castIndex, zero};
138 auto updated = tensor::InsertSliceOp::create(
139 builder, loc, slice, args[0], outputOffset, sizes, strides);
140
141 return {updated};
142 };
143
144 auto loops = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps,
145 ValueRange{valuesIn}, buildBody);
146 rewriter.replaceOp(scatter, loops.results);
147
148 return success();
149 }
150};
151
152class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
153public:
154 using OpRewritePattern<tosa::WhileOp>::OpRewritePattern;
155
156 LogicalResult matchAndRewrite(tosa::WhileOp op,
157 PatternRewriter &rewriter) const final {
158 auto newWhile = scf::WhileOp::create(
159 rewriter, op.getLoc(), op.getResultTypes(), op.getInputList());
160 rewriter.createBlock(&newWhile.getBefore());
161 rewriter.createBlock(&newWhile.getAfter());
162
163 inlineWhileCase(op.getCondGraph(), newWhile.getBefore(), rewriter, true);
164 inlineWhileCase(op.getBodyGraph(), newWhile.getAfter(), rewriter, false);
165
166 rewriter.replaceOp(op, newWhile.getResults());
167
168 return success();
169 }
170};
171
172} // namespace
173
176 patterns->add<IfOpConverter, ScatterOpConverter, WhileOpConverter>(
177 patterns->getContext());
178}
return success()
static void inlineIfCase(Region &srcRegion, Region &dstRegion, OperandRange operands, PatternRewriter &rewriter)
Definition TosaToSCF.cpp:22
static void inlineWhileCase(Region &srcRegion, Region &dstRegion, PatternRewriter &rewriter, bool isCond)
Definition TosaToSCF.cpp:39
Block represents an ordered list of Operations.
Definition Block.h:33
unsigned getNumArguments()
Definition Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition Block.cpp:201
BlockArgListType getArguments()
Definition Block.h:87
IndexType getIndexType()
Definition Builders.cpp:51
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition Builders.cpp:589
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
Block & back()
Definition Region.h:64
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition SCF.cpp:837
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition SCF.h:64
void populateTosaToSCFConversionPatterns(RewritePatternSet *patterns)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...