MLIR  22.0.0git
UpliftWhileToFor.cpp
Go to the documentation of this file.
1 //===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Transforms SCF.WhileOp's into SCF.ForOp's.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/Dominance.h"
17 #include "mlir/IR/PatternMatch.h"
18 
19 using namespace mlir;
20 
21 namespace {
22 struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
24 
25  LogicalResult matchAndRewrite(scf::WhileOp loop,
26  PatternRewriter &rewriter) const override {
27  return upliftWhileToForLoop(rewriter, loop);
28  }
29 };
30 } // namespace
31 
32 FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
33  scf::WhileOp loop) {
34  Block *beforeBody = loop.getBeforeBody();
35  if (!llvm::hasSingleElement(beforeBody->without_terminator()))
36  return rewriter.notifyMatchFailure(loop, "Loop body must have single op");
37 
38  auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front());
39  if (!cmp)
40  return rewriter.notifyMatchFailure(loop,
41  "Loop body must have single cmp op");
42 
43  scf::ConditionOp beforeTerm = loop.getConditionOp();
44  if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult())
45  return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
46  diag << "Expected single condition use: " << *cmp;
47  });
48 
49  // If all 'before' arguments are forwarded but the order is different from
50  // 'after' arguments, here is the mapping from the 'after' argument index to
51  // the 'before' argument index.
52  std::optional<SmallVector<unsigned>> argReorder;
53  // All `before` block args must be directly forwarded to ConditionOp.
54  // They will be converted to `scf.for` `iter_vars` except induction var.
55  if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs()) {
56  auto getArgReordering =
57  [](Block *beforeBody,
58  scf::ConditionOp cond) -> std::optional<SmallVector<unsigned>> {
59  // Skip further checking if their sizes mismatch.
60  if (beforeBody->getNumArguments() != cond.getArgs().size())
61  return std::nullopt;
62  // Bitset on which 'before' argument is forwarded.
63  llvm::SmallBitVector forwarded(beforeBody->getNumArguments(), false);
64  // The forwarding order of 'before' arguments.
66  for (Value a : cond.getArgs()) {
67  BlockArgument arg = dyn_cast<BlockArgument>(a);
68  // Skip if 'arg' is not a 'before' argument.
69  if (!arg || arg.getOwner() != beforeBody)
70  return std::nullopt;
71  unsigned idx = arg.getArgNumber();
72  // Skip if 'arg' is already forwarded in another place.
73  if (forwarded[idx])
74  return std::nullopt;
75  // Record the presence of 'arg' and its order.
76  forwarded[idx] = true;
77  order.push_back(idx);
78  }
79  // Skip if not all 'before' arguments are forwarded.
80  if (!forwarded.all())
81  return std::nullopt;
82  return order;
83  };
84  // Check if 'before' arguments are all forwarded but just reordered.
85  argReorder = getArgReordering(beforeBody, beforeTerm);
86  if (!argReorder)
87  return rewriter.notifyMatchFailure(loop, "Invalid args order");
88  }
89 
90  using Pred = arith::CmpIPredicate;
91  Pred predicate = cmp.getPredicate();
92  if (predicate != Pred::slt && predicate != Pred::sgt)
93  return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
94  diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
95  });
96 
97  BlockArgument inductionVar;
98  Value ub;
99  DominanceInfo dom;
100 
101  // Check if cmp has a suitable form. One of the arguments must be a `before`
102  // block arg, other must be defined outside `scf.while` and will be treated
103  // as upper bound.
104  for (bool reverse : {false, true}) {
105  auto expectedPred = reverse ? Pred::sgt : Pred::slt;
106  if (cmp.getPredicate() != expectedPred)
107  continue;
108 
109  auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
110  auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
111 
112  auto blockArg = dyn_cast<BlockArgument>(arg1);
113  if (!blockArg || blockArg.getOwner() != beforeBody)
114  continue;
115 
116  if (!dom.properlyDominates(arg2, loop))
117  continue;
118 
119  inductionVar = blockArg;
120  ub = arg2;
121  break;
122  }
123 
124  if (!inductionVar)
125  return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
126  diag << "Unrecognized cmp form: " << *cmp;
127  });
128 
129  // inductionVar must have 2 uses: one is in `cmp` and other is `condition`
130  // arg.
131  if (!llvm::hasNItems(inductionVar.getUses(), 2))
132  return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
133  diag << "Unrecognized induction var: " << inductionVar;
134  });
135 
136  Block *afterBody = loop.getAfterBody();
137  scf::YieldOp afterTerm = loop.getYieldOp();
138  unsigned argNumber = inductionVar.getArgNumber();
139  Value afterTermIndArg = afterTerm.getResults()[argNumber];
140 
141  auto findAfterArgNo = [](ArrayRef<unsigned> indices, unsigned beforeArgNo) {
142  return std::distance(indices.begin(),
143  llvm::find_if(indices, [beforeArgNo](unsigned n) {
144  return n == beforeArgNo;
145  }));
146  };
147  Value inductionVarAfter = afterBody->getArgument(
148  argReorder ? findAfterArgNo(*argReorder, argNumber) : argNumber);
149 
150  // Find suitable `addi` op inside `after` block, one of the args must be an
151  // Induction var passed from `before` block and second arg must be defined
152  // outside of the loop and will be considered step value.
153  // TODO: Add `subi` support?
154  auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>();
155  if (!addOp)
156  return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
157 
158  Value step;
159  if (addOp.getLhs() == inductionVarAfter) {
160  step = addOp.getRhs();
161  } else if (addOp.getRhs() == inductionVarAfter) {
162  step = addOp.getLhs();
163  }
164 
165  if (!step || !dom.properlyDominates(step, loop))
166  return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form");
167 
168  Value lb = loop.getInits()[argNumber];
169 
170  assert(lb.getType().isIntOrIndex());
171  assert(lb.getType() == ub.getType());
172  assert(lb.getType() == step.getType());
173 
174  SmallVector<Value> newArgs;
175 
176  // Populate inits for new `scf.for`, skip induction var.
177  newArgs.reserve(loop.getInits().size());
178  for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
179  if (i == argNumber)
180  continue;
181 
182  newArgs.emplace_back(init);
183  }
184 
185  Location loc = loop.getLoc();
186 
187  // With `builder == nullptr`, ForOp::build will try to insert terminator at
188  // the end of newly created block and we don't want it. Provide empty
189  // dummy builder instead.
190  auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
191  auto newLoop =
192  scf::ForOp::create(rewriter, loc, lb, ub, step, newArgs, emptyBuilder);
193 
194  Block *newBody = newLoop.getBody();
195 
196  // Populate block args for `scf.for` body, move induction var to the front.
197  newArgs.clear();
198  ValueRange newBodyArgs = newBody->getArguments();
199  for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) {
200  if (i < argNumber) {
201  newArgs.emplace_back(newBodyArgs[i + 1]);
202  } else if (i == argNumber) {
203  newArgs.emplace_back(newBodyArgs.front());
204  } else {
205  newArgs.emplace_back(newBodyArgs[i]);
206  }
207  }
208  if (argReorder) {
209  // Reorder arguments following the 'after' argument order from the original
210  // 'while' loop.
211  SmallVector<Value> args;
212  for (unsigned order : *argReorder)
213  args.push_back(newArgs[order]);
214  newArgs = args;
215  }
216 
217  rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
218  newArgs);
219 
220  auto term = cast<scf::YieldOp>(newBody->getTerminator());
221 
222  // Populate new yield args, skipping the induction var.
223  newArgs.clear();
224  for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
225  if (i == argNumber)
226  continue;
227 
228  newArgs.emplace_back(arg);
229  }
230 
231  OpBuilder::InsertionGuard g(rewriter);
232  rewriter.setInsertionPoint(term);
233  rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs);
234 
235  // Compute induction var value after loop execution.
236  rewriter.setInsertionPointAfter(newLoop);
237  Value one;
238  if (isa<IndexType>(step.getType())) {
239  one = arith::ConstantIndexOp::create(rewriter, loc, 1);
240  } else {
241  one = arith::ConstantIntOp::create(rewriter, loc, step.getType(), 1);
242  }
243 
244  Value stepDec = arith::SubIOp::create(rewriter, loc, step, one);
245  Value len = arith::SubIOp::create(rewriter, loc, ub, lb);
246  len = arith::AddIOp::create(rewriter, loc, len, stepDec);
247  len = arith::DivSIOp::create(rewriter, loc, len, step);
248  len = arith::SubIOp::create(rewriter, loc, len, one);
249  Value res = arith::MulIOp::create(rewriter, loc, len, step);
250  res = arith::AddIOp::create(rewriter, loc, lb, res);
251 
252  // Reconstruct `scf.while` results, inserting final induction var value
253  // into proper place.
254  newArgs.clear();
255  llvm::append_range(newArgs, newLoop.getResults());
256  newArgs.insert(newArgs.begin() + argNumber, res);
257  if (argReorder) {
258  // Reorder arguments following the 'after' argument order from the original
259  // 'while' loop.
260  SmallVector<Value> results;
261  for (unsigned order : *argReorder)
262  results.push_back(newArgs[order]);
263  newArgs = results;
264  }
265  rewriter.replaceOp(loop, newArgs);
266  return newLoop;
267 }
268 
270  patterns.add<UpliftWhileOp>(patterns.getContext());
271 }
static std::string diag(const llvm::Value &value)
This class represents an argument of a Block.
Definition: Value.h:309
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:318
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:321
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:212
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
A class for computing basic dominance information.
Definition: Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.cpp:323
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:112
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition: ArithOps.cpp:258
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
FailureOr< ForOp > upliftWhileToForLoop(RewriterBase &rewriter, WhileOp loop)
Try to uplift scf.while op to scf.for.
void populateUpliftWhileToForPatterns(RewritePatternSet &patterns)
Populate patterns to uplift scf.while ops to scf.for.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319