MLIR  22.0.0git
ForToWhile.cpp
Go to the documentation of this file.
1 //===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Transforms scf.for operations into equivalent scf.while operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/SCF/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

21 namespace mlir {
22 #define GEN_PASS_DEF_SCFFORTOWHILELOOP
23 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace llvm;
27 using namespace mlir;
28 using scf::ForOp;
29 using scf::WhileOp;
30 
31 namespace {
32 
33 struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
35 
36  LogicalResult matchAndRewrite(ForOp forOp,
37  PatternRewriter &rewriter) const override {
38  // Generate type signature for the loop-carried values. The induction
39  // variable is placed first, followed by the forOp.iterArgs.
40  SmallVector<Type> lcvTypes;
41  SmallVector<Location> lcvLocs;
42  lcvTypes.push_back(forOp.getInductionVar().getType());
43  lcvLocs.push_back(forOp.getInductionVar().getLoc());
44  for (Value value : forOp.getInitArgs()) {
45  lcvTypes.push_back(value.getType());
46  lcvLocs.push_back(value.getLoc());
47  }
48 
49  // Build scf.WhileOp
50  SmallVector<Value> initArgs;
51  initArgs.push_back(forOp.getLowerBound());
52  llvm::append_range(initArgs, forOp.getInitArgs());
53  auto whileOp = WhileOp::create(rewriter, forOp.getLoc(), lcvTypes, initArgs,
54  forOp->getAttrs());
55 
56  // 'before' region contains the loop condition and forwarding of iteration
57  // arguments to the 'after' region.
58  auto *beforeBlock = rewriter.createBlock(
59  &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
60  rewriter.setInsertionPointToStart(whileOp.getBeforeBody());
61  arith::CmpIPredicate predicate = forOp.getUnsignedCmp()
62  ? arith::CmpIPredicate::ult
63  : arith::CmpIPredicate::slt;
64  auto cmpOp = arith::CmpIOp::create(rewriter, whileOp.getLoc(), predicate,
65  beforeBlock->getArgument(0),
66  forOp.getUpperBound());
67  scf::ConditionOp::create(rewriter, whileOp.getLoc(), cmpOp.getResult(),
68  beforeBlock->getArguments());
69 
70  // Inline for-loop body into an executeRegion operation in the "after"
71  // region. The return type of the execRegionOp does not contain the
72  // iv - yields in the source for-loop contain only iterArgs.
73  auto *afterBlock = rewriter.createBlock(
74  &whileOp.getAfter(), whileOp.getAfter().begin(), lcvTypes, lcvLocs);
75 
76  // Add induction variable incrementation
77  rewriter.setInsertionPointToEnd(afterBlock);
78  auto ivIncOp =
79  arith::AddIOp::create(rewriter, whileOp.getLoc(),
80  afterBlock->getArgument(0), forOp.getStep());
81 
82  // Rewrite uses of the for-loop block arguments to the new while-loop
83  // "after" arguments
84  for (const auto &barg : enumerate(forOp.getBody(0)->getArguments()))
85  rewriter.replaceAllUsesWith(barg.value(),
86  afterBlock->getArgument(barg.index()));
87 
88  // Inline for-loop body operations into 'after' region.
89  for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
90  rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end());
91 
92  // Add incremented IV to yield operations
93  for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
94  SmallVector<Value> yieldOperands = yieldOp.getOperands();
95  yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
96  rewriter.modifyOpInPlace(yieldOp,
97  [&]() { yieldOp->setOperands(yieldOperands); });
98  }
99 
100  // We cannot do a direct replacement of the forOp since the while op returns
101  // an extra value (the induction variable escapes the loop through being
102  // carried in the set of iterargs). Instead, rewrite uses of the forOp
103  // results.
104  for (const auto &arg : llvm::enumerate(forOp.getResults()))
105  rewriter.replaceAllUsesWith(arg.value(),
106  whileOp.getResult(arg.index() + 1));
107 
108  rewriter.eraseOp(forOp);
109  return success();
110  }
111 };
112 
113 struct ForToWhileLoop : public impl::SCFForToWhileLoopBase<ForToWhileLoop> {
114  void runOnOperation() override {
115  auto *parentOp = getOperation();
116  MLIRContext *ctx = parentOp->getContext();
118  patterns.add<ForLoopLoweringPattern>(ctx);
119  (void)applyPatternsGreedily(parentOp, std::move(patterns));
120  }
121 };
122 } // namespace
123 
124 std::unique_ptr<Pass> mlir::createForToWhileLoopPass() {
125  return std::make_unique<ForToWhileLoop>();
126 }
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:429
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:436
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:646
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::unique_ptr< Pass > createForToWhileLoopPass()
Definition: ForToWhile.cpp:124
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314