MLIR  22.0.0git
TosaToSCF.cpp
Go to the documentation of this file.
1 //===- TosaToSCF.cpp - Lowering Tosa to SCF Dialect -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the SCF dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "mlir/IR/PatternMatch.h"
18 
19 using namespace mlir;
20 using namespace tosa;
21 
22 static void inlineIfCase(Region &srcRegion, Region &dstRegion,
23  OperandRange operands, PatternRewriter &rewriter) {
24  rewriter.cloneRegionBefore(srcRegion, &dstRegion.front());
25  rewriter.eraseBlock(&dstRegion.back());
26 
27  Block *headBlock = &dstRegion.front();
28  for (auto it : llvm::zip(headBlock->getArguments(), operands))
29  std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
30 
31  auto yield = cast<YieldOp>(headBlock->getTerminator());
32  rewriter.setInsertionPoint(yield);
33  scf::YieldOp::create(rewriter, yield.getLoc(), yield.getInputs());
34  rewriter.eraseOp(yield);
35 
36  headBlock->eraseArguments(0, headBlock->getNumArguments());
37 }
38 
39 static void inlineWhileCase(Region &srcRegion, Region &dstRegion,
40  PatternRewriter &rewriter, bool isCond) {
41  rewriter.cloneRegionBefore(srcRegion, &dstRegion.back());
42  rewriter.eraseBlock(&dstRegion.back());
43 
44  Block *headBlock = &dstRegion.front();
45 
46  auto yield = cast<YieldOp>(headBlock->getTerminator());
47  rewriter.setInsertionPoint(yield);
48  if (isCond) {
49  auto condition = tensor::ExtractOp::create(rewriter, yield.getLoc(),
50  yield.getOperand(0));
51  scf::ConditionOp::create(rewriter, yield.getLoc(), condition,
52  headBlock->getArguments());
53  } else {
54  rewriter.setInsertionPoint(yield);
55  scf::YieldOp::create(rewriter, yield.getLoc(), yield.getInputs());
56  }
57  rewriter.eraseOp(yield);
58 }
59 
60 namespace {
61 
62 class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
63 public:
65 
66  LogicalResult matchAndRewrite(tosa::IfOp op,
67  PatternRewriter &rewriter) const final {
68  auto condition =
69  tensor::ExtractOp::create(rewriter, op.getLoc(), op.getCondition());
70  auto newIf = scf::IfOp::create(rewriter, op.getLoc(), op.getResultTypes(),
71  condition, true);
72 
73  inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputList(),
74  rewriter);
75  inlineIfCase(op.getElseGraph(), newIf.getElseRegion(), op.getInputList(),
76  rewriter);
77 
78  rewriter.replaceOp(op, newIf.getResults());
79  return success();
80  }
81 };
82 
83 class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
84  static Value createTensorDim(OpBuilder &builder, Location loc, Value tensor,
85  int64_t dim) {
86  return builder.createOrFold<tensor::DimOp>(loc, tensor, dim);
87  }
88 
89  static Value createIndexConst(OpBuilder &builder, Location loc,
90  int64_t value) {
91  return arith::ConstantIndexOp::create(builder, loc, value);
92  }
93 
94 public:
96 
97  LogicalResult matchAndRewrite(tosa::ScatterOp scatter,
98  PatternRewriter &rewriter) const final {
99  auto valuesIn = scatter.getValuesIn();
100  auto indices = scatter.getIndices();
101  auto input = scatter.getInput();
102  auto loc = scatter.getLoc();
103 
104  // N, W, C are chosen to match the TOSA spec
105  auto dimN = createTensorDim(rewriter, loc, input, 0);
106  auto dimW = createTensorDim(rewriter, loc, input, 1);
107  auto dimC = createTensorDim(rewriter, loc, input, 2);
108 
109  auto zero = createIndexConst(rewriter, loc, 0);
110  auto one = createIndexConst(rewriter, loc, 1);
111 
112  // Loop bounds
113  auto lbs = llvm::SmallVector<Value>(2, zero);
114  auto steps = llvm::SmallVector<Value>(2, one);
115  auto ubs = llvm::SmallVector<Value>{{dimN, dimW}};
116 
117  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
118  ValueRange args) -> scf::ValueVector {
119  auto n = ivs[0];
120 
121  // Read the index and cast it to index type
122  auto index = tensor::ExtractOp::create(builder, loc, indices, ivs);
123  auto castIndex = arith::IndexCastOp::create(
124  builder, loc, builder.getIndexType(), index);
125 
126  // Offset, sizes, and strides for the input tensor
127  auto inputOffset = llvm::to_vector(ivs);
128  inputOffset.push_back(zero);
129 
130  llvm::SmallVector<Value> sizes = {one, one, dimC};
131  llvm::SmallVector<Value> strides = {one, one, one};
132 
133  auto slice = tensor::ExtractSliceOp::create(builder, loc, input,
134  inputOffset, sizes, strides);
135 
136  // Insert the slice into the output accumulator tensor.
137  llvm::SmallVector<Value> outputOffset = {n, castIndex, zero};
138  auto updated = tensor::InsertSliceOp::create(
139  builder, loc, slice, args[0], outputOffset, sizes, strides);
140 
141  return {updated};
142  };
143 
144  auto loops = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps,
145  ValueRange{valuesIn}, buildBody);
146  rewriter.replaceOp(scatter, loops.results);
147 
148  return success();
149  }
150 };
151 
152 class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
153 public:
155 
156  LogicalResult matchAndRewrite(tosa::WhileOp op,
157  PatternRewriter &rewriter) const final {
158  auto newWhile = scf::WhileOp::create(
159  rewriter, op.getLoc(), op.getResultTypes(), op.getInputList());
160  rewriter.createBlock(&newWhile.getBefore());
161  rewriter.createBlock(&newWhile.getAfter());
162 
163  inlineWhileCase(op.getCondGraph(), newWhile.getBefore(), rewriter, true);
164  inlineWhileCase(op.getBodyGraph(), newWhile.getAfter(), rewriter, false);
165 
166  rewriter.replaceOp(op, newWhile.getResults());
167 
168  return success();
169  }
170 };
171 
172 } // namespace
173 
176  patterns->add<IfOpConverter, ScatterOpConverter, WhileOpConverter>(
177  patterns->getContext());
178 }
static void inlineIfCase(Region &srcRegion, Region &dstRegion, OperandRange operands, PatternRewriter &rewriter)
Definition: TosaToSCF.cpp:22
static void inlineWhileCase(Region &srcRegion, Region &dstRegion, PatternRewriter &rewriter, bool isCond)
Definition: TosaToSCF.cpp:39
Block represents an ordered list of Operations.
Definition: Block.h:33
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:201
BlockArgListType getArguments()
Definition: Block.h:87
IndexType getIndexType()
Definition: Builders.cpp:50
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: UseDefLists.h:211
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:575
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & back()
Definition: Region.h:64
Block & front()
Definition: Region.h:65
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition: SCF.cpp:707
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:64
void populateTosaToSCFConversionPatterns(RewritePatternSet *patterns)
Definition: TosaToSCF.cpp:174
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314