MLIR  19.0.0git
ForToWhile.cpp
Go to the documentation of this file.
1 //===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Transforms scf.for operations into scf.while operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
18 #include "mlir/IR/PatternMatch.h"
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_SCFFORTOWHILELOOP
23 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace llvm;
27 using namespace mlir;
28 using scf::ForOp;
29 using scf::WhileOp;
30 
31 namespace {
32 
33 struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
35 
36  LogicalResult matchAndRewrite(ForOp forOp,
37  PatternRewriter &rewriter) const override {
38  // Generate type signature for the loop-carried values. The induction
39  // variable is placed first, followed by the forOp.iterArgs.
40  SmallVector<Type> lcvTypes;
41  SmallVector<Location> lcvLocs;
42  lcvTypes.push_back(forOp.getInductionVar().getType());
43  lcvLocs.push_back(forOp.getInductionVar().getLoc());
44  for (Value value : forOp.getInitArgs()) {
45  lcvTypes.push_back(value.getType());
46  lcvLocs.push_back(value.getLoc());
47  }
48 
49  // Build scf.WhileOp
50  SmallVector<Value> initArgs;
51  initArgs.push_back(forOp.getLowerBound());
52  llvm::append_range(initArgs, forOp.getInitArgs());
53  auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs,
54  forOp->getAttrs());
55 
56  // 'before' region contains the loop condition and forwarding of iteration
57  // arguments to the 'after' region.
58  auto *beforeBlock = rewriter.createBlock(
59  &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
60  rewriter.setInsertionPointToStart(whileOp.getBeforeBody());
61  auto cmpOp = rewriter.create<arith::CmpIOp>(
62  whileOp.getLoc(), arith::CmpIPredicate::slt,
63  beforeBlock->getArgument(0), forOp.getUpperBound());
64  rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(),
65  beforeBlock->getArguments());
66 
67  // Inline for-loop body into an executeRegion operation in the "after"
68  // region. The return type of the execRegionOp does not contain the
69  // iv - yields in the source for-loop contain only iterArgs.
70  auto *afterBlock = rewriter.createBlock(
71  &whileOp.getAfter(), whileOp.getAfter().begin(), lcvTypes, lcvLocs);
72 
73  // Add induction variable incrementation
74  rewriter.setInsertionPointToEnd(afterBlock);
75  auto ivIncOp = rewriter.create<arith::AddIOp>(
76  whileOp.getLoc(), afterBlock->getArgument(0), forOp.getStep());
77 
78  // Rewrite uses of the for-loop block arguments to the new while-loop
79  // "after" arguments
80  for (const auto &barg : enumerate(forOp.getBody(0)->getArguments()))
81  rewriter.replaceAllUsesWith(barg.value(),
82  afterBlock->getArgument(barg.index()));
83 
84  // Inline for-loop body operations into 'after' region.
85  for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
86  rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end());
87 
88  // Add incremented IV to yield operations
89  for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
90  SmallVector<Value> yieldOperands = yieldOp.getOperands();
91  yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
92  rewriter.modifyOpInPlace(yieldOp,
93  [&]() { yieldOp->setOperands(yieldOperands); });
94  }
95 
96  // We cannot do a direct replacement of the forOp since the while op returns
97  // an extra value (the induction variable escapes the loop through being
98  // carried in the set of iterargs). Instead, rewrite uses of the forOp
99  // results.
100  for (const auto &arg : llvm::enumerate(forOp.getResults()))
101  rewriter.replaceAllUsesWith(arg.value(),
102  whileOp.getResult(arg.index() + 1));
103 
104  rewriter.eraseOp(forOp);
105  return success();
106  }
107 };
108 
109 struct ForToWhileLoop : public impl::SCFForToWhileLoopBase<ForToWhileLoop> {
110  void runOnOperation() override {
111  auto *parentOp = getOperation();
112  MLIRContext *ctx = parentOp->getContext();
113  RewritePatternSet patterns(ctx);
114  patterns.add<ForLoopLoweringPattern>(ctx);
115  (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
116  }
117 };
118 } // namespace
119 
120 std::unique_ptr<Pass> mlir::createForToWhileLoopPass() {
121  return std::make_unique<ForToWhileLoop>();
122 }
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:785
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp, which may be in the same or another block in the same function.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Include the generated interface declarations.
Definition: CallGraph.h:229
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::unique_ptr< Pass > createForToWhileLoopPass()
Definition: ForToWhile.cpp:120
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy, worklist-driven manner until a fixpoint is reached.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.
Definition: PatternMatch.h:358