MLIR  22.0.0git
LoopInvariantCodeMotionUtils.cpp
Go to the documentation of this file.
1 //===- LoopInvariantCodeMotionUtils.cpp - LICM Utils ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the core LICM algorithm.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "mlir/IR/Operation.h"
17 #include "mlir/IR/PatternMatch.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/DebugLog.h"
23 #include <queue>
24 
25 #define DEBUG_TYPE "licm"
26 
27 using namespace mlir;
28 
29 /// Checks whether the given op can be hoisted by checking that
30 /// - the op and none of its contained operations depend on values inside of the
31 /// loop (by means of calling definedOutside).
32 /// - the op has no side-effects.
33 static bool canBeHoisted(Operation *op,
34  function_ref<bool(OpOperand &)> condition) {
35  // Do not move terminators.
37  return false;
38 
39  // Walk the nested operations and check that all used values are either
40  // defined outside of the loop or in a nested region, but not at the level of
41  // the loop body.
42  auto walkFn = [&](Operation *child) {
43  for (OpOperand &operand : child->getOpOperands()) {
44  // Ignore values defined in a nested region.
45  if (op->isAncestor(operand.get().getParentRegion()->getParentOp()))
46  continue;
47  if (!condition(operand))
48  return WalkResult::interrupt();
49  }
50  return WalkResult::advance();
51  };
52  return !op->walk(walkFn).wasInterrupted();
53 }
54 
55 static bool canBeHoisted(Operation *op,
56  function_ref<bool(Value)> definedOutside) {
57  return canBeHoisted(
58  op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
59 }
60 
62  ArrayRef<Region *> regions,
63  function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
64  function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
65  function_ref<void(Operation *, Region *)> moveOutOfRegion) {
66  size_t numMoved = 0;
67 
68  for (Region *region : regions) {
69  LDBG() << "Original loop:\n" << *region->getParentOp();
70 
71  std::queue<Operation *> worklist;
72  // Add top-level operations in the loop body to the worklist.
73  for (Operation &op : region->getOps())
74  worklist.push(&op);
75 
76  auto definedOutside = [&](Value value) {
77  return isDefinedOutsideRegion(value, region);
78  };
79 
80  while (!worklist.empty()) {
81  Operation *op = worklist.front();
82  worklist.pop();
83  // Skip ops that have already been moved. Check if the op can be hoisted.
84  if (op->getParentRegion() != region)
85  continue;
86 
87  LDBG() << "Checking op: "
88  << OpWithFlags(op, OpPrintingFlags().skipRegions());
89  if (!shouldMoveOutOfRegion(op, region) ||
90  !canBeHoisted(op, definedOutside))
91  continue;
92 
93  LDBG() << "Moving loop-invariant op: " << *op;
94  moveOutOfRegion(op, region);
95  ++numMoved;
96 
97  // Since the op has been moved, we need to check its users within the
98  // top-level of the loop body.
99  for (Operation *user : op->getUsers())
100  if (user->getParentRegion() == region)
101  worklist.push(user);
102  }
103  }
104 
105  return numMoved;
106 }
107 
108 size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
109  return moveLoopInvariantCode(
110  loopLike.getLoopRegions(),
111  [&](Value value, Region *) {
112  return loopLike.isDefinedOutsideOfLoop(value);
113  },
114  [&](Operation *op, Region *) {
115  return isMemoryEffectFree(op) && isSpeculatable(op);
116  },
117  [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
118 }
119 
120 namespace {
121 /// Helper data structure that keeps track of equivalent/disjoint subset ops.
122 class MatchingSubsets {
123 public:
124  /// Insert a subset op.
125  void insert(SubsetOpInterface op, bool collectHoistableOps = true) {
126  allSubsetOps.push_back(op);
127  if (!collectHoistableOps)
128  return;
129  if (auto extractionOp =
130  dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
131  insertExtractionOp(extractionOp);
132  if (auto insertionOp =
133  dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
134  insertInsertionOp(insertionOp);
135  }
136 
137  /// Return a range of matching extraction-insertion subset ops. If there is no
138  /// matching extraction/insertion op, the respective value is empty. Ops are
139  /// skipped if there are other subset ops that are not guaranteed to operate
140  /// on disjoint subsets.
141  auto getHoistableSubsetOps() {
142  return llvm::make_filter_range(
143  llvm::zip(extractions, insertions), [&](auto pair) {
144  auto [extractionOp, insertionOp] = pair;
145  // Hoist only if the extracted and inserted values have the same type.
146  if (extractionOp && insertionOp &&
147  extractionOp->getResult(0).getType() !=
148  insertionOp.getSourceOperand().get().getType())
149  return false;
150  // Hoist only if there are no conflicting subset ops.
151  return allDisjoint(extractionOp, insertionOp);
152  });
153  }
154 
155  /// Populate subset ops starting from the given region iter_arg. Return
156  /// "failure" if non-subset ops are found along the path to the loop yielding
157  /// op or if there is no single path to the tied yielded operand. If
158  /// `collectHoistableOps` is set to "false", subset ops are gathered
159  /// throughout the traversal, but not enumerated by `getHoistableSubsetOps`.
160  LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
161  BlockArgument iterArg,
162  bool collectHoistableOps = true);
163 
164 private:
165  /// Helper function for equivalence of tensor values. Since only insertion
166  /// subset ops (that are also destination style ops) are followed when
167  /// traversing the SSA use-def chain, all tensor values are equivalent.
168  static bool isEquivalent(Value v1, Value v2) { return true; }
169 
170  /// Return "true" if the subsets of the given extraction and insertion ops
171  /// are operating disjoint from the subsets that all other known subset ops
172  /// are operating on.
173  bool allDisjoint(SubsetExtractionOpInterface extractionOp,
174  SubsetInsertionOpInterface insertionOp) const {
175  for (SubsetOpInterface other : allSubsetOps) {
176  if (other == extractionOp || other == insertionOp)
177  continue;
178  if (extractionOp &&
179  !other.operatesOnDisjointSubset(extractionOp, isEquivalent))
180  return false;
181  if (insertionOp &&
182  !other.operatesOnDisjointSubset(insertionOp, isEquivalent))
183  return false;
184  }
185  return true;
186  }
187 
188  /// Insert a subset extraction op. If the subset is equivalent to an existing
189  /// subset insertion op, pair them up. (If there is already a paired up subset
190  /// extraction op, overwrite the subset extraction op.)
191  void insertExtractionOp(SubsetExtractionOpInterface extractionOp) {
192  for (auto it : llvm::enumerate(insertions)) {
193  if (!it.value())
194  continue;
195  auto other = cast<SubsetOpInterface>(it.value().getOperation());
196  if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) {
197  extractions[it.index()] = extractionOp;
198  return;
199  }
200  }
201  // There is no known equivalent insertion op. Create a new entry.
202  extractions.push_back(extractionOp);
203  insertions.push_back({});
204  }
205 
206  /// Insert a subset insertion op. If the subset is equivalent to an existing
207  /// subset extraction op, pair them up. (If there is already a paired up
208  /// subset insertion op, overwrite the subset insertion op.)
209  void insertInsertionOp(SubsetInsertionOpInterface insertionOp) {
210  for (auto it : llvm::enumerate(extractions)) {
211  if (!it.value())
212  continue;
213  auto other = cast<SubsetOpInterface>(it.value().getOperation());
214  if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) {
215  insertions[it.index()] = insertionOp;
216  return;
217  }
218  }
219  // There is no known equivalent extraction op. Create a new entry.
220  extractions.push_back({});
221  insertions.push_back(insertionOp);
222  }
223 
226  SmallVector<SubsetOpInterface> allSubsetOps;
227 };
228 } // namespace
229 
230 /// If the given value has a single use by an op that is a terminator, return
231 /// that use. Otherwise, return nullptr.
233  if (!value.hasOneUse())
234  return nullptr;
235  OpOperand &use = *value.getUses().begin();
237  return &use;
238  return nullptr;
239 }
240 
241 LogicalResult
242 MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
243  BlockArgument iterArg,
244  bool collectHoistableOps) {
245  assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
246  Value value = iterArg;
247 
248  // Traverse use-def chain. Subset ops can be hoisted only if all ops along the
249  // use-def chain starting from the region iter_arg are subset extraction or
250  // subset insertion ops. The chain must terminate at the corresponding yield
251  // operand (e.g., no swapping of iter_args).
252  OpOperand *yieldedOperand = nullptr;
253  // Iterate until the single use of the current SSA value is a terminator,
254  // which is expected to be the yielding operation of the loop.
255  while (!(yieldedOperand = getSingleTerminatorUse(value))) {
256  Value nextValue = {};
257 
258  for (OpOperand &use : value.getUses()) {
259  if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
260  // Subset ops in nested loops are collected to check if there are only
261  // disjoint subset ops, but such subset ops are not subject to hoisting.
262  // To hoist subset ops from nested loops, the hoisting transformation
263  // should be run on the nested loop.
264  auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use);
265  if (!nestedIterArg)
266  return failure();
267  // Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA
268  // use-def chain starting at `nestedIterArg` and terminating in the
269  // tied, yielding operand.
270  if (failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg,
271  /*collectHoistableOps=*/false)))
272  return failure();
273  nextValue = nestedLoop.getTiedLoopResult(&use);
274  continue;
275  }
276 
277  auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
278  if (!subsetOp)
279  return failure();
280  insert(subsetOp);
281 
282  if (auto insertionOp =
283  dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
284  // Current implementation expects that the insertionOp implement
285  // the DestinationStyleOpInterface and with pure tensor semantics
286  // as well. Abort if that is not the case.
287  auto dstOp = dyn_cast<DestinationStyleOpInterface>(use.getOwner());
288  if (!dstOp || !dstOp.hasPureTensorSemantics())
289  return failure();
290 
291  // The value must be used as a destination. (In case of a source, the
292  // entire tensor would be read, which would prevent any hoisting.)
293  if (&use != &insertionOp.getDestinationOperand())
294  return failure();
295  // There must be a single use-def chain from the region iter_arg to the
296  // terminator. I.e., only one insertion op. Branches are not supported.
297  if (nextValue)
298  return failure();
299  nextValue = insertionOp.getUpdatedDestination();
300  }
301  }
302 
303  // Nothing can be hoisted if the chain does not continue with loop yielding
304  // op or a subset insertion op.
305  if (!nextValue)
306  return failure();
307  value = nextValue;
308  }
309 
310  // Hoist only if the SSA use-def chain ends in the yielding terminator of the
311  // loop and the yielded value is the `idx`-th operand. (I.e., there is no
312  // swapping yield.)
313  if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
314  return failure();
315 
316  return success();
317 }
318 
319 /// Hoist all subset ops that operate on the idx-th region iter_arg of the given
320 /// loop-like op and index into loop-invariant subset locations. Return the
321 /// newly created loop op (that has extra iter_args) or the original loop op if
322 /// nothing was hoisted.
323 static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
324  LoopLikeOpInterface loopLike,
325  BlockArgument iterArg) {
326  assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
327  BlockArgument *it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
328  int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
329  MatchingSubsets subsets;
330  if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
331  return loopLike;
332 
333  // Hoist all matching extraction-insertion pairs one-by-one.
334  for (auto it : subsets.getHoistableSubsetOps()) {
335  auto extractionOp = std::get<0>(it);
336  auto insertionOp = std::get<1>(it);
337 
338  // Ops cannot be hoisted if they depend on loop-variant values.
339  if (extractionOp) {
340  if (!canBeHoisted(extractionOp, [&](OpOperand &operand) {
341  return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
342  &operand == &extractionOp.getSourceOperand();
343  }))
344  extractionOp = {};
345  }
346  if (insertionOp) {
347  if (!canBeHoisted(insertionOp, [&](OpOperand &operand) {
348  return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
349  &operand == &insertionOp.getSourceOperand() ||
350  &operand == &insertionOp.getDestinationOperand();
351  }))
352  insertionOp = {};
353  }
354 
355  // Only hoist extraction-insertion pairs for now. Standalone extractions/
356  // insertions that are loop-invariant could be hoisted, but there may be
357  // easier ways to canonicalize the IR.
358  if (extractionOp && insertionOp) {
359  // Create a new loop with an additional iter_arg.
360  NewYieldValuesFn newYieldValuesFn =
361  [&](OpBuilder &b, Location loc,
362  ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
363  return {insertionOp.getSourceOperand().get()};
364  };
365  FailureOr<LoopLikeOpInterface> newLoop =
366  loopLike.replaceWithAdditionalYields(
367  rewriter, extractionOp.getResult(),
368  /*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn);
369  if (failed(newLoop))
370  return loopLike;
371  loopLike = *newLoop;
372 
373  // Hoist the extraction/insertion ops.
374  iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
375  OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
376  OpResult newLoopResult = loopLike.getLoopResults()->back();
377  rewriter.moveOpBefore(extractionOp, loopLike);
378  rewriter.moveOpAfter(insertionOp, loopLike);
379  rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
380  insertionOp.getDestinationOperand().get());
381  extractionOp.getSourceOperand().set(
382  loopLike.getTiedLoopInit(iterArg)->get());
383  rewriter.replaceAllUsesWith(loopResult,
384  insertionOp.getUpdatedDestination());
385  insertionOp.getSourceOperand().set(newLoopResult);
386  insertionOp.getDestinationOperand().set(loopResult);
387  }
388  }
389 
390  return loopLike;
391 }
392 
393 LoopLikeOpInterface
395  LoopLikeOpInterface loopLike) {
396  // Note: As subset ops are getting hoisted, the number of region iter_args
397  // increases. This can enable further hoisting opportunities on the new
398  // iter_args.
399  for (int64_t i = 0;
400  i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) {
401  loopLike = hoistSubsetAtIterArg(rewriter, loopLike,
402  loopLike.getRegionIterArgs()[i]);
403  }
404  return loopLike;
405 }
static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter, LoopLikeOpInterface loopLike, BlockArgument iterArg)
Hoist all subset ops that operate on the idx-th region iter_arg of the given loop-like op and index into loop-invariant subset locations.
static OpOperand * getSingleTerminatorUse(Value value)
If the given value has a single use by an op that is a terminator, return that use.
static bool canBeHoisted(Operation *op, function_ref< bool(OpOperand &)> condition)
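Checks whether the given op can be hoisted by checking that the op is not a terminator and that neither the op nor any of its contained operations depends on values inside of the loop.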
This class represents an argument of a Block.
Definition: Value.h:309
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:318
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:205
This class represents an operand of an operation.
Definition: Value.h:257
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition: Value.h:447
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:773
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition: Operation.h:1111
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:636
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:197
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
LoopLikeOpInterface hoistLoopInvariantSubsets(RewriterBase &rewriter, LoopLikeOpInterface loopLike)
Hoist loop-invariant tensor subsets (subset extraction and subset insertion ops) from loop-like ops.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.