MLIR  16.0.0git
TosaToSCF.cpp
Go to the documentation of this file.
1 //===- TosaToSCF.cpp - Lowering Tosa to SCF Dialect -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the SCF dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/PatternMatch.h"
20 
21 using namespace mlir;
22 using namespace tosa;
23 
24 static void inlineIfCase(Region &srcRegion, Region &dstRegion,
25  OperandRange operands, PatternRewriter &rewriter) {
26  rewriter.cloneRegionBefore(srcRegion, &dstRegion.front());
27  rewriter.eraseBlock(&dstRegion.back());
28 
29  Block *headBlock = &dstRegion.front();
30  for (auto it : llvm::zip(headBlock->getArguments(), operands))
31  std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
32 
33  auto yield = cast<YieldOp>(headBlock->getTerminator());
34  rewriter.setInsertionPoint(yield);
35  rewriter.create<scf::YieldOp>(yield.getLoc(), yield.getInputs());
36  rewriter.eraseOp(yield);
37 
38  headBlock->eraseArguments(0, headBlock->getNumArguments());
39 }
40 
41 static void inlineWhileCase(Region &srcRegion, Region &dstRegion,
42  PatternRewriter &rewriter, bool isCond) {
43  rewriter.cloneRegionBefore(srcRegion, &dstRegion.back());
44  rewriter.eraseBlock(&dstRegion.back());
45 
46  Block *headBlock = &dstRegion.front();
47 
48  auto yield = cast<YieldOp>(headBlock->getTerminator());
49  rewriter.setInsertionPoint(yield);
50  if (isCond) {
51  auto condition =
52  rewriter.create<tensor::ExtractOp>(yield.getLoc(), yield.getOperand(0));
53  rewriter.create<scf::ConditionOp>(yield.getLoc(), condition,
54  headBlock->getArguments());
55  } else {
56  rewriter.setInsertionPoint(yield);
57  rewriter.create<scf::YieldOp>(yield.getLoc(), yield.getInputs());
58  }
59  rewriter.eraseOp(yield);
60 }
61 
62 namespace {
63 
64 class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
65 public:
67 
68  LogicalResult matchAndRewrite(tosa::IfOp op,
69  PatternRewriter &rewriter) const final {
70  auto condition =
71  rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCond());
72  auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
73  condition, true);
74 
75  inlineIfCase(op.getThenBranch(), newIf.getThenRegion(), op.getInputs(),
76  rewriter);
77  inlineIfCase(op.getElseBranch(), newIf.getElseRegion(), op.getInputs(),
78  rewriter);
79 
80  rewriter.replaceOp(op, newIf.getResults());
81  return success();
82  }
83 };
84 
85 class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
86 public:
88 
89  LogicalResult matchAndRewrite(tosa::WhileOp op,
90  PatternRewriter &rewriter) const final {
91  auto newWhile = rewriter.create<scf::WhileOp>(
92  op.getLoc(), op.getResultTypes(), op.getInputs());
93  rewriter.createBlock(&newWhile.getBefore());
94  rewriter.createBlock(&newWhile.getAfter());
95 
96  inlineWhileCase(op.getCond(), newWhile.getBefore(), rewriter, true);
97  inlineWhileCase(op.getBody(), newWhile.getAfter(), rewriter, false);
98 
99  rewriter.replaceOp(op, newWhile.getResults());
100 
101  return success();
102  }
103 };
104 
105 } // namespace
106 
108  RewritePatternSet *patterns) {
109  patterns->add<IfOpConverter>(patterns->getContext());
110  patterns->add<WhileOpConverter>(patterns->getContext());
111 }
static void inlineIfCase(Region &srcRegion, Region &dstRegion, OperandRange operands, PatternRewriter &rewriter)
Definition: TosaToSCF.cpp:24
static void inlineWhileCase(Region &srcRegion, Region &dstRegion, PatternRewriter &rewriter, bool isCond)
Definition: TosaToSCF.cpp:41
Block represents an ordered list of Operations.
Definition: Block.h:30
unsigned getNumArguments()
Definition: Block.h:117
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:232
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:189
BlockArgListType getArguments()
Definition: Block.h:76
void replaceAllUsesWith(ValueT &&newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to use the new value instead.
Definition: UseDefLists.h:188
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:350
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:41
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:605
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & back()
Definition: Region.h:64
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, BlockAndValueMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void populateTosaToSCFConversionPatterns(RewritePatternSet *patterns)
Definition: TosaToSCF.cpp:107
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.
Definition: PatternMatch.h:356