MLIR  19.0.0git
TosaToSCF.cpp
Go to the documentation of this file.
1 //===- TosaToSCF.cpp - Lowering Tosa to SCF Dialect -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the SCF dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "mlir/IR/IRMapping.h"
18 #include "mlir/IR/PatternMatch.h"
20 
21 using namespace mlir;
22 using namespace tosa;
23 
24 static void inlineIfCase(Region &srcRegion, Region &dstRegion,
25  OperandRange operands, PatternRewriter &rewriter) {
26  rewriter.cloneRegionBefore(srcRegion, &dstRegion.front());
27  rewriter.eraseBlock(&dstRegion.back());
28 
29  Block *headBlock = &dstRegion.front();
30  for (auto it : llvm::zip(headBlock->getArguments(), operands))
31  std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
32 
33  auto yield = cast<YieldOp>(headBlock->getTerminator());
34  rewriter.setInsertionPoint(yield);
35  rewriter.create<scf::YieldOp>(yield.getLoc(), yield.getInputs());
36  rewriter.eraseOp(yield);
37 
38  headBlock->eraseArguments(0, headBlock->getNumArguments());
39 }
40 
41 static void inlineWhileCase(Region &srcRegion, Region &dstRegion,
42  PatternRewriter &rewriter, bool isCond) {
43  rewriter.cloneRegionBefore(srcRegion, &dstRegion.back());
44  rewriter.eraseBlock(&dstRegion.back());
45 
46  Block *headBlock = &dstRegion.front();
47 
48  auto yield = cast<YieldOp>(headBlock->getTerminator());
49  rewriter.setInsertionPoint(yield);
50  if (isCond) {
51  auto condition =
52  rewriter.create<tensor::ExtractOp>(yield.getLoc(), yield.getOperand(0));
53  rewriter.create<scf::ConditionOp>(yield.getLoc(), condition,
54  headBlock->getArguments());
55  } else {
56  rewriter.setInsertionPoint(yield);
57  rewriter.create<scf::YieldOp>(yield.getLoc(), yield.getInputs());
58  }
59  rewriter.eraseOp(yield);
60 }
61 
62 namespace {
63 
64 class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
65 public:
67 
68  LogicalResult matchAndRewrite(tosa::IfOp op,
69  PatternRewriter &rewriter) const final {
70  auto condition =
71  rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCond());
72  auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
73  condition, true);
74 
75  inlineIfCase(op.getThenBranch(), newIf.getThenRegion(), op.getInputs(),
76  rewriter);
77  inlineIfCase(op.getElseBranch(), newIf.getElseRegion(), op.getInputs(),
78  rewriter);
79 
80  rewriter.replaceOp(op, newIf.getResults());
81  return success();
82  }
83 };
84 
85 class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
86  static Value createTensorDim(OpBuilder &builder, Location loc, Value tensor,
87  int64_t dim) {
88  return builder.createOrFold<tensor::DimOp>(loc, tensor, dim);
89  }
90 
91  static Value createIndexConst(OpBuilder &builder, Location loc,
92  int64_t value) {
93  return builder.create<arith::ConstantIndexOp>(loc, value);
94  }
95 
96 public:
98 
99  LogicalResult matchAndRewrite(tosa::ScatterOp scatter,
100  PatternRewriter &rewriter) const final {
101  auto valuesIn = scatter.getValuesIn();
102  auto indices = scatter.getIndices();
103  auto input = scatter.getInput();
104  auto loc = scatter.getLoc();
105 
106  // N, W, C are chosen to match the TOSA spec
107  auto dimN = createTensorDim(rewriter, loc, input, 0);
108  auto dimW = createTensorDim(rewriter, loc, input, 1);
109  auto dimC = createTensorDim(rewriter, loc, input, 2);
110 
111  auto zero = createIndexConst(rewriter, loc, 0);
112  auto one = createIndexConst(rewriter, loc, 1);
113 
114  // Loop bounds
115  auto lbs = llvm::SmallVector<Value>(2, zero);
116  auto steps = llvm::SmallVector<Value>(2, one);
117  auto ubs = llvm::SmallVector<Value>{{dimN, dimW}};
118 
119  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
120  ValueRange args) -> scf::ValueVector {
121  auto n = ivs[0];
122 
123  // Read the index and cast it to index type
124  auto index = builder.create<tensor::ExtractOp>(loc, indices, ivs);
125  auto castIndex = builder.create<arith::IndexCastOp>(
126  loc, builder.getIndexType(), index);
127 
128  // Offset, sizes, and strides for the input tensor
129  auto inputOffset = llvm::to_vector(ivs);
130  inputOffset.push_back(zero);
131 
132  llvm::SmallVector<Value> sizes = {one, one, dimC};
133  llvm::SmallVector<Value> strides = {one, one, one};
134 
135  auto slice = builder.create<tensor::ExtractSliceOp>(
136  loc, input, inputOffset, sizes, strides);
137 
138  // Insert the slice into the output accumulator tensor.
139  llvm::SmallVector<Value> outputOffset = {n, castIndex, zero};
140  auto updated = builder.create<tensor::InsertSliceOp>(
141  loc, slice, args[0], outputOffset, sizes, strides);
142 
143  return {updated};
144  };
145 
146  auto loops = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps,
147  ValueRange{valuesIn}, buildBody);
148  rewriter.replaceOp(scatter, loops.results);
149 
150  return success();
151  }
152 };
153 
154 class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
155 public:
157 
158  LogicalResult matchAndRewrite(tosa::WhileOp op,
159  PatternRewriter &rewriter) const final {
160  auto newWhile = rewriter.create<scf::WhileOp>(
161  op.getLoc(), op.getResultTypes(), op.getInputs());
162  rewriter.createBlock(&newWhile.getBefore());
163  rewriter.createBlock(&newWhile.getAfter());
164 
165  inlineWhileCase(op.getCond(), newWhile.getBefore(), rewriter, true);
166  inlineWhileCase(op.getBody(), newWhile.getAfter(), rewriter, false);
167 
168  rewriter.replaceOp(op, newWhile.getResults());
169 
170  return success();
171  }
172 };
173 
174 } // namespace
175 
177  RewritePatternSet *patterns) {
178  patterns->add<IfOpConverter, ScatterOpConverter, WhileOpConverter>(
179  patterns->getContext());
180 }
static void inlineIfCase(Region &srcRegion, Region &dstRegion, OperandRange operands, PatternRewriter &rewriter)
Definition: TosaToSCF.cpp:24
static void inlineWhileCase(Region &srcRegion, Region &dstRegion, PatternRewriter &rewriter, bool isCond)
Definition: TosaToSCF.cpp:41
Block represents an ordered list of Operations.
Definition: Block.h:30
unsigned getNumArguments()
Definition: Block.h:125
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:200
BlockArgListType getArguments()
Definition: Block.h:84
IndexType getIndexType()
Definition: Builders.cpp:71
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:582
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_range getResultTypes()
Definition: Operation.h:423
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & back()
Definition: Region.h:64
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition: SCF.cpp:687
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:70
void populateTosaToSCFConversionPatterns(RewritePatternSet *patterns)
Definition: TosaToSCF.cpp:176
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358