MLIR  19.0.0git
UpliftWhileToFor.cpp
Go to the documentation of this file.
1 //===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Transforms `scf.while` ops into equivalent `scf.for` ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
18 #include "mlir/IR/Dominance.h"
19 #include "mlir/IR/PatternMatch.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
26 
27  LogicalResult matchAndRewrite(scf::WhileOp loop,
28  PatternRewriter &rewriter) const override {
29  return upliftWhileToForLoop(rewriter, loop);
30  }
31 };
32 } // namespace
33 
35  scf::WhileOp loop) {
36  Block *beforeBody = loop.getBeforeBody();
37  if (!llvm::hasSingleElement(beforeBody->without_terminator()))
38  return rewriter.notifyMatchFailure(loop, "Loop body must have single op");
39 
40  auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front());
41  if (!cmp)
42  return rewriter.notifyMatchFailure(loop,
43  "Loop body must have single cmp op");
44 
45  scf::ConditionOp beforeTerm = loop.getConditionOp();
46  if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult())
47  return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
48  diag << "Expected single condition use: " << *cmp;
49  });
50 
51  // All `before` block args must be directly forwarded to ConditionOp.
52  // They will be converted to `scf.for` `iter_vars` except induction var.
53  if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
54  return rewriter.notifyMatchFailure(loop, "Invalid args order");
55 
56  using Pred = arith::CmpIPredicate;
57  Pred predicate = cmp.getPredicate();
58  if (predicate != Pred::slt && predicate != Pred::sgt)
59  return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
60  diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
61  });
62 
63  BlockArgument inductionVar;
64  Value ub;
65  DominanceInfo dom;
66 
67  // Check if cmp has a suitable form. One of the arguments must be a `before`
68  // block arg, other must be defined outside `scf.while` and will be treated
69  // as upper bound.
70  for (bool reverse : {false, true}) {
71  auto expectedPred = reverse ? Pred::sgt : Pred::slt;
72  if (cmp.getPredicate() != expectedPred)
73  continue;
74 
75  auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
76  auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
77 
78  auto blockArg = dyn_cast<BlockArgument>(arg1);
79  if (!blockArg || blockArg.getOwner() != beforeBody)
80  continue;
81 
82  if (!dom.properlyDominates(arg2, loop))
83  continue;
84 
85  inductionVar = blockArg;
86  ub = arg2;
87  break;
88  }
89 
90  if (!inductionVar)
91  return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
92  diag << "Unrecognized cmp form: " << *cmp;
93  });
94 
95  // inductionVar must have 2 uses: one is in `cmp` and other is `condition`
96  // arg.
97  if (!llvm::hasNItems(inductionVar.getUses(), 2))
98  return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
99  diag << "Unrecognized induction var: " << inductionVar;
100  });
101 
102  Block *afterBody = loop.getAfterBody();
103  scf::YieldOp afterTerm = loop.getYieldOp();
104  unsigned argNumber = inductionVar.getArgNumber();
105  Value afterTermIndArg = afterTerm.getResults()[argNumber];
106 
107  Value inductionVarAfter = afterBody->getArgument(argNumber);
108 
109  // Find suitable `addi` op inside `after` block, one of the args must be an
110  // Induction var passed from `before` block and second arg must be defined
111  // outside of the loop and will be considered step value.
112  // TODO: Add `subi` support?
113  auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>();
114  if (!addOp)
115  return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
116 
117  Value step;
118  if (addOp.getLhs() == inductionVarAfter) {
119  step = addOp.getRhs();
120  } else if (addOp.getRhs() == inductionVarAfter) {
121  step = addOp.getLhs();
122  }
123 
124  if (!step || !dom.properlyDominates(step, loop))
125  return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form");
126 
127  Value lb = loop.getInits()[argNumber];
128 
129  assert(lb.getType().isIntOrIndex());
130  assert(lb.getType() == ub.getType());
131  assert(lb.getType() == step.getType());
132 
133  llvm::SmallVector<Value> newArgs;
134 
135  // Populate inits for new `scf.for`, skip induction var.
136  newArgs.reserve(loop.getInits().size());
137  for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
138  if (i == argNumber)
139  continue;
140 
141  newArgs.emplace_back(init);
142  }
143 
144  Location loc = loop.getLoc();
145 
146  // With `builder == nullptr`, ForOp::build will try to insert terminator at
147  // the end of newly created block and we don't want it. Provide empty
148  // dummy builder instead.
149  auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
150  auto newLoop =
151  rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder);
152 
153  Block *newBody = newLoop.getBody();
154 
155  // Populate block args for `scf.for` body, move induction var to the front.
156  newArgs.clear();
157  ValueRange newBodyArgs = newBody->getArguments();
158  for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) {
159  if (i < argNumber) {
160  newArgs.emplace_back(newBodyArgs[i + 1]);
161  } else if (i == argNumber) {
162  newArgs.emplace_back(newBodyArgs.front());
163  } else {
164  newArgs.emplace_back(newBodyArgs[i]);
165  }
166  }
167 
168  rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
169  newArgs);
170 
171  auto term = cast<scf::YieldOp>(newBody->getTerminator());
172 
173  // Populate new yield args, skipping the induction var.
174  newArgs.clear();
175  for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
176  if (i == argNumber)
177  continue;
178 
179  newArgs.emplace_back(arg);
180  }
181 
182  OpBuilder::InsertionGuard g(rewriter);
183  rewriter.setInsertionPoint(term);
184  rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs);
185 
186  // Compute induction var value after loop execution.
187  rewriter.setInsertionPointAfter(newLoop);
188  Value one;
189  if (isa<IndexType>(step.getType())) {
190  one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
191  } else {
192  one = rewriter.create<arith::ConstantIntOp>(loc, 1, step.getType());
193  }
194 
195  Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
196  Value len = rewriter.create<arith::SubIOp>(loc, ub, lb);
197  len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
198  len = rewriter.create<arith::DivSIOp>(loc, len, step);
199  len = rewriter.create<arith::SubIOp>(loc, len, one);
200  Value res = rewriter.create<arith::MulIOp>(loc, len, step);
201  res = rewriter.create<arith::AddIOp>(loc, lb, res);
202 
203  // Reconstruct `scf.while` results, inserting final induction var value
204  // into proper place.
205  newArgs.clear();
206  llvm::append_range(newArgs, newLoop.getResults());
207  newArgs.insert(newArgs.begin() + argNumber, res);
208  rewriter.replaceOp(loop, newArgs);
209  return newLoop;
210 }
211 
213  patterns.add<UpliftWhileOp>(patterns.getContext());
214 }
static std::string diag(const llvm::Value &value)
This class represents an argument of a Block.
Definition: Value.h:319
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
void clear()
Definition: Block.h:35
BlockArgListType getArguments()
Definition: Block.h:84
Operation & front()
Definition: Block.h:150
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:206
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
A class for computing basic dominance information.
Definition: Dominance.h:136
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.h:149
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:115
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
FailureOr< ForOp > upliftWhileToForLoop(RewriterBase &rewriter, WhileOp loop)
Try to uplift scf.while op to scf.for.
void populateUpliftWhileToForPatterns(RewritePatternSet &patterns)
Populate patterns to uplift scf.while ops to scf.for.
Include the generated interface declarations.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362