MLIR  22.0.0git
ForToWhile.cpp
Go to the documentation of this file.
1 //===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Transforms SCF.ForOp's into SCF.WhileOp's.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
18 #include "mlir/IR/PatternMatch.h"
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_SCFFORTOWHILELOOP
23 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 using scf::ForOp;
28 using scf::WhileOp;
29 
30 namespace {
31 
32 struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
34 
35  LogicalResult matchAndRewrite(ForOp forOp,
36  PatternRewriter &rewriter) const override {
37  // Generate type signature for the loop-carried values. The induction
38  // variable is placed first, followed by the forOp.iterArgs.
39  SmallVector<Type> lcvTypes;
40  SmallVector<Location> lcvLocs;
41  lcvTypes.push_back(forOp.getInductionVar().getType());
42  lcvLocs.push_back(forOp.getInductionVar().getLoc());
43  for (Value value : forOp.getInitArgs()) {
44  lcvTypes.push_back(value.getType());
45  lcvLocs.push_back(value.getLoc());
46  }
47 
48  // Build scf.WhileOp
49  SmallVector<Value> initArgs;
50  initArgs.push_back(forOp.getLowerBound());
51  llvm::append_range(initArgs, forOp.getInitArgs());
52  auto whileOp = WhileOp::create(rewriter, forOp.getLoc(), lcvTypes, initArgs,
53  forOp->getAttrs());
54 
55  // 'before' region contains the loop condition and forwarding of iteration
56  // arguments to the 'after' region.
57  auto *beforeBlock = rewriter.createBlock(
58  &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
59  rewriter.setInsertionPointToStart(whileOp.getBeforeBody());
60  arith::CmpIPredicate predicate = forOp.getUnsignedCmp()
61  ? arith::CmpIPredicate::ult
62  : arith::CmpIPredicate::slt;
63  auto cmpOp = arith::CmpIOp::create(rewriter, whileOp.getLoc(), predicate,
64  beforeBlock->getArgument(0),
65  forOp.getUpperBound());
66  scf::ConditionOp::create(rewriter, whileOp.getLoc(), cmpOp.getResult(),
67  beforeBlock->getArguments());
68 
69  // Inline for-loop body into an executeRegion operation in the "after"
70  // region. The return type of the execRegionOp does not contain the
71  // iv - yields in the source for-loop contain only iterArgs.
72  auto *afterBlock = rewriter.createBlock(
73  &whileOp.getAfter(), whileOp.getAfter().begin(), lcvTypes, lcvLocs);
74 
75  // Add induction variable incrementation
76  rewriter.setInsertionPointToEnd(afterBlock);
77  auto ivIncOp =
78  arith::AddIOp::create(rewriter, whileOp.getLoc(),
79  afterBlock->getArgument(0), forOp.getStep());
80 
81  // Rewrite uses of the for-loop block arguments to the new while-loop
82  // "after" arguments
83  for (const auto &barg : enumerate(forOp.getBody(0)->getArguments()))
84  rewriter.replaceAllUsesWith(barg.value(),
85  afterBlock->getArgument(barg.index()));
86 
87  // Inline for-loop body operations into 'after' region.
88  for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
89  rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end());
90 
91  // Add incremented IV to yield operations
92  for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
93  SmallVector<Value> yieldOperands = yieldOp.getOperands();
94  yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
95  rewriter.modifyOpInPlace(yieldOp,
96  [&]() { yieldOp->setOperands(yieldOperands); });
97  }
98 
99  // We cannot do a direct replacement of the forOp since the while op returns
100  // an extra value (the induction variable escapes the loop through being
101  // carried in the set of iterargs). Instead, rewrite uses of the forOp
102  // results.
103  for (const auto &arg : llvm::enumerate(forOp.getResults()))
104  rewriter.replaceAllUsesWith(arg.value(),
105  whileOp.getResult(arg.index() + 1));
106 
107  rewriter.eraseOp(forOp);
108  return success();
109  }
110 };
111 
112 struct ForToWhileLoop : public impl::SCFForToWhileLoopBase<ForToWhileLoop> {
113  void runOnOperation() override {
114  auto *parentOp = getOperation();
115  MLIRContext *ctx = parentOp->getContext();
117  patterns.add<ForLoopLoweringPattern>(ctx);
118  (void)applyPatternsGreedily(parentOp, std::move(patterns));
119  }
120 };
121 } // namespace
122 
123 std::unique_ptr<Pass> mlir::createForToWhileLoopPass() {
124  return std::make_unique<ForToWhileLoop>();
125 }
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:436
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:793
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp, which may be in the same or another block in the same function.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:646
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy, worklist-driven manner until a fixpoint is reached.
std::unique_ptr< Pass > createForToWhileLoopPass()
Definition: ForToWhile.cpp:123
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.
Definition: PatternMatch.h:314