MLIR  16.0.0git
ForToWhile.cpp
Go to the documentation of this file.
1 //===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Transforms SCF.ForOp's into SCF.WhileOp's.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
18 #include "mlir/IR/PatternMatch.h"
20 
21 using namespace llvm;
22 using namespace mlir;
23 using scf::ForOp;
24 using scf::WhileOp;
25 
26 namespace {
27 
28 struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
30 
31  LogicalResult matchAndRewrite(ForOp forOp,
32  PatternRewriter &rewriter) const override {
33  // Generate type signature for the loop-carried values. The induction
34  // variable is placed first, followed by the forOp.iterArgs.
35  SmallVector<Type> lcvTypes;
36  SmallVector<Location> lcvLocs;
37  lcvTypes.push_back(forOp.getInductionVar().getType());
38  lcvLocs.push_back(forOp.getInductionVar().getLoc());
39  for (Value value : forOp.getInitArgs()) {
40  lcvTypes.push_back(value.getType());
41  lcvLocs.push_back(value.getLoc());
42  }
43 
44  // Build scf.WhileOp
45  SmallVector<Value> initArgs;
46  initArgs.push_back(forOp.getLowerBound());
47  llvm::append_range(initArgs, forOp.getInitArgs());
48  auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs,
49  forOp->getAttrs());
50 
51  // 'before' region contains the loop condition and forwarding of iteration
52  // arguments to the 'after' region.
53  auto *beforeBlock = rewriter.createBlock(
54  &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
55  rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
56  auto cmpOp = rewriter.create<arith::CmpIOp>(
57  whileOp.getLoc(), arith::CmpIPredicate::slt,
58  beforeBlock->getArgument(0), forOp.getUpperBound());
59  rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(),
60  beforeBlock->getArguments());
61 
62  // Inline for-loop body into an executeRegion operation in the "after"
63  // region. The return type of the execRegionOp does not contain the
64  // iv - yields in the source for-loop contain only iterArgs.
65  auto *afterBlock = rewriter.createBlock(
66  &whileOp.getAfter(), whileOp.getAfter().begin(), lcvTypes, lcvLocs);
67 
68  // Add induction variable incrementation
69  rewriter.setInsertionPointToEnd(afterBlock);
70  auto ivIncOp = rewriter.create<arith::AddIOp>(
71  whileOp.getLoc(), afterBlock->getArgument(0), forOp.getStep());
72 
73  // Rewrite uses of the for-loop block arguments to the new while-loop
74  // "after" arguments
75  for (const auto &barg : enumerate(forOp.getBody(0)->getArguments()))
76  barg.value().replaceAllUsesWith(afterBlock->getArgument(barg.index()));
77 
78  // Inline for-loop body operations into 'after' region.
79  for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
80  arg.moveBefore(afterBlock, afterBlock->end());
81 
82  // Add incremented IV to yield operations
83  for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
84  SmallVector<Value> yieldOperands = yieldOp.getOperands();
85  yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
86  yieldOp->setOperands(yieldOperands);
87  }
88 
89  // We cannot do a direct replacement of the forOp since the while op returns
90  // an extra value (the induction variable escapes the loop through being
91  // carried in the set of iterargs). Instead, rewrite uses of the forOp
92  // results.
93  for (const auto &arg : llvm::enumerate(forOp.getResults()))
94  arg.value().replaceAllUsesWith(whileOp.getResult(arg.index() + 1));
95 
96  rewriter.eraseOp(forOp);
97  return success();
98  }
99 };
100 
101 struct ForToWhileLoop : public SCFForToWhileLoopBase<ForToWhileLoop> {
102  void runOnOperation() override {
103  auto *parentOp = getOperation();
104  MLIRContext *ctx = parentOp->getContext();
105  RewritePatternSet patterns(ctx);
106  patterns.add<ForLoopLoweringPattern>(ctx);
107  (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
108  }
109 };
110 } // namespace
111 
112 std::unique_ptr<Pass> mlir::createForToWhileLoopPass() {
113  return std::make_unique<ForToWhileLoop>();
114 }
Include the generated interface declarations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:221
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static constexpr const bool value
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:233
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:203
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:377
std::unique_ptr< Pass > createForToWhileLoopPass()
Definition: ForToWhile.cpp:112
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:382
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add new block with 'argTypes' arguments and set the insertion point to the end of it...
Definition: Builders.cpp:377
LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef< Region > regions, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig())
Rewrite the regions of the specified operation, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy work-list driven manner.