MLIR  22.0.0git
LoopInvariantCodeMotionUtils.cpp
Go to the documentation of this file.
1 //===- LoopInvariantCodeMotionUtils.cpp - LICM Utils ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the core LICM algorithm.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "mlir/IR/Operation.h"
17 #include "mlir/IR/PatternMatch.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/DebugLog.h"
23 #include <queue>
24 
25 #define DEBUG_TYPE "licm"
26 
27 using namespace mlir;
28 
29 /// Checks whether the given op can be hoisted by checking that
30 /// - the op and none of its contained operations depend on values inside of the
31 /// loop (by means of calling definedOutside).
32 /// - the op has no side-effects.
33 static bool canBeHoisted(Operation *op,
34  function_ref<bool(OpOperand &)> condition) {
35  // Do not move terminators.
37  return false;
38 
39  // Walk the nested operations and check that all used values are either
40  // defined outside of the loop or in a nested region, but not at the level of
41  // the loop body.
42  auto walkFn = [&](Operation *child) {
43  for (OpOperand &operand : child->getOpOperands()) {
44  // Ignore values defined in a nested region.
45  if (op->isAncestor(operand.get().getParentRegion()->getParentOp()))
46  continue;
47  if (!condition(operand))
48  return WalkResult::interrupt();
49  }
50  return WalkResult::advance();
51  };
52  return !op->walk(walkFn).wasInterrupted();
53 }
54 
55 static bool canBeHoisted(Operation *op,
56  function_ref<bool(Value)> definedOutside) {
57  return canBeHoisted(
58  op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
59 }
60 
62  ArrayRef<Region *> regions,
63  function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
64  function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
65  function_ref<void(Operation *, Region *)> moveOutOfRegion) {
66  size_t numMoved = 0;
67 
68  for (Region *region : regions) {
69  LDBG() << "Original loop:\n"
70  << OpWithFlags(region->getParentOp(),
71  OpPrintingFlags().skipRegions());
72 
73  std::queue<Operation *> worklist;
74  // Add top-level operations in the loop body to the worklist.
75  for (Operation &op : region->getOps())
76  worklist.push(&op);
77 
78  auto definedOutside = [&](Value value) {
79  return isDefinedOutsideRegion(value, region);
80  };
81 
82  while (!worklist.empty()) {
83  Operation *op = worklist.front();
84  worklist.pop();
85  // Skip ops that have already been moved. Check if the op can be hoisted.
86  if (op->getParentRegion() != region)
87  continue;
88 
89  LDBG() << "Checking op: "
90  << OpWithFlags(op, OpPrintingFlags().skipRegions());
91  if (!shouldMoveOutOfRegion(op, region) ||
92  !canBeHoisted(op, definedOutside))
93  continue;
94 
95  LDBG() << "Moving loop-invariant op: "
96  << OpWithFlags(op, OpPrintingFlags().skipRegions());
97  moveOutOfRegion(op, region);
98  ++numMoved;
99 
100  // Since the op has been moved, we need to check its users within the
101  // top-level of the loop body.
102  for (Operation *user : op->getUsers())
103  if (user->getParentRegion() == region)
104  worklist.push(user);
105  }
106  }
107 
108  return numMoved;
109 }
110 
111 size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
112  return moveLoopInvariantCode(
113  loopLike.getLoopRegions(),
114  [&](Value value, Region *) {
115  return loopLike.isDefinedOutsideOfLoop(value);
116  },
117  [&](Operation *op, Region *) { return isPure(op); },
118  [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
119 }
120 
121 namespace {
122 /// Helper data structure that keeps track of equivalent/disjoint subset ops.
123 class MatchingSubsets {
124 public:
125  /// Insert a subset op.
126  void insert(SubsetOpInterface op, bool collectHoistableOps = true) {
127  allSubsetOps.push_back(op);
128  if (!collectHoistableOps)
129  return;
130  if (auto extractionOp =
131  dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
132  insertExtractionOp(extractionOp);
133  if (auto insertionOp =
134  dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
135  insertInsertionOp(insertionOp);
136  }
137 
138  /// Return a range of matching extraction-insertion subset ops. If there is no
139  /// matching extraction/insertion op, the respective value is empty. Ops are
140  /// skipped if there are other subset ops that are not guaranteed to operate
141  /// on disjoint subsets.
142  auto getHoistableSubsetOps() {
143  return llvm::make_filter_range(
144  llvm::zip(extractions, insertions), [&](auto pair) {
145  auto [extractionOp, insertionOp] = pair;
146  // Hoist only if the extracted and inserted values have the same type.
147  if (extractionOp && insertionOp &&
148  extractionOp->getResult(0).getType() !=
149  insertionOp.getSourceOperand().get().getType())
150  return false;
151  // Hoist only if there are no conflicting subset ops.
152  return allDisjoint(extractionOp, insertionOp);
153  });
154  }
155 
156  /// Populate subset ops starting from the given region iter_arg. Return
157  /// "failure" if non-subset ops are found along the path to the loop yielding
158  /// op or if there is no single path to the tied yielded operand. If
159  /// `collectHoistableOps` is set to "false", subset ops are gathered
160  /// throughout the traversal, but not enumerated by `getHoistableSubsetOps`.
161  LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
162  BlockArgument iterArg,
163  bool collectHoistableOps = true);
164 
165 private:
166  /// Helper function for equivalence of tensor values. Since only insertion
167  /// subset ops (that are also destination style ops) are followed when
168  /// traversing the SSA use-def chain, all tensor values are equivalent.
169  static bool isEquivalent(Value v1, Value v2) { return true; }
170 
171  /// Return "true" if the subsets of the given extraction and insertion ops
172  /// are operating disjoint from the subsets that all other known subset ops
173  /// are operating on.
174  bool allDisjoint(SubsetExtractionOpInterface extractionOp,
175  SubsetInsertionOpInterface insertionOp) const {
176  for (SubsetOpInterface other : allSubsetOps) {
177  if (other == extractionOp || other == insertionOp)
178  continue;
179  if (extractionOp &&
180  !other.operatesOnDisjointSubset(extractionOp, isEquivalent))
181  return false;
182  if (insertionOp &&
183  !other.operatesOnDisjointSubset(insertionOp, isEquivalent))
184  return false;
185  }
186  return true;
187  }
188 
189  /// Insert a subset extraction op. If the subset is equivalent to an existing
190  /// subset insertion op, pair them up. (If there is already a paired up subset
191  /// extraction op, overwrite the subset extraction op.)
192  void insertExtractionOp(SubsetExtractionOpInterface extractionOp) {
193  for (auto it : llvm::enumerate(insertions)) {
194  if (!it.value())
195  continue;
196  auto other = cast<SubsetOpInterface>(it.value().getOperation());
197  if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) {
198  extractions[it.index()] = extractionOp;
199  return;
200  }
201  }
202  // There is no known equivalent insertion op. Create a new entry.
203  extractions.push_back(extractionOp);
204  insertions.push_back({});
205  }
206 
207  /// Insert a subset insertion op. If the subset is equivalent to an existing
208  /// subset extraction op, pair them up. (If there is already a paired up
209  /// subset insertion op, overwrite the subset insertion op.)
210  void insertInsertionOp(SubsetInsertionOpInterface insertionOp) {
211  for (auto it : llvm::enumerate(extractions)) {
212  if (!it.value())
213  continue;
214  auto other = cast<SubsetOpInterface>(it.value().getOperation());
215  if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) {
216  insertions[it.index()] = insertionOp;
217  return;
218  }
219  }
220  // There is no known equivalent extraction op. Create a new entry.
221  extractions.push_back({});
222  insertions.push_back(insertionOp);
223  }
224 
227  SmallVector<SubsetOpInterface> allSubsetOps;
228 };
229 } // namespace
230 
231 /// If the given value has a single use by an op that is a terminator, return
232 /// that use. Otherwise, return nullptr.
234  if (!value.hasOneUse())
235  return nullptr;
236  OpOperand &use = *value.getUses().begin();
238  return &use;
239  return nullptr;
240 }
241 
242 LogicalResult
243 MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
244  BlockArgument iterArg,
245  bool collectHoistableOps) {
246  assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
247  Value value = iterArg;
248 
249  // Traverse use-def chain. Subset ops can be hoisted only if all ops along the
250  // use-def chain starting from the region iter_arg are subset extraction or
251  // subset insertion ops. The chain must terminate at the corresponding yield
252  // operand (e.g., no swapping of iter_args).
253  OpOperand *yieldedOperand = nullptr;
254  // Iterate until the single use of the current SSA value is a terminator,
255  // which is expected to be the yielding operation of the loop.
256  while (!(yieldedOperand = getSingleTerminatorUse(value))) {
257  Value nextValue = {};
258 
259  for (OpOperand &use : value.getUses()) {
260  if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
261  // Subset ops in nested loops are collected to check if there are only
262  // disjoint subset ops, but such subset ops are not subject to hoisting.
263  // To hoist subset ops from nested loops, the hoisting transformation
264  // should be run on the nested loop.
265  auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use);
266  if (!nestedIterArg)
267  return failure();
268  // Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA
269  // use-def chain starting at `nestedIterArg` and terminating in the
270  // tied, yielding operand.
271  if (failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg,
272  /*collectHoistableOps=*/false)))
273  return failure();
274  nextValue = nestedLoop.getTiedLoopResult(&use);
275  continue;
276  }
277 
278  auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
279  if (!subsetOp)
280  return failure();
281  insert(subsetOp);
282 
283  if (auto insertionOp =
284  dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
285  // Current implementation expects that the insertionOp implement
286  // the DestinationStyleOpInterface and with pure tensor semantics
287  // as well. Abort if that is not the case.
288  auto dstOp = dyn_cast<DestinationStyleOpInterface>(use.getOwner());
289  if (!dstOp || !dstOp.hasPureTensorSemantics())
290  return failure();
291 
292  // The value must be used as a destination. (In case of a source, the
293  // entire tensor would be read, which would prevent any hoisting.)
294  if (&use != &insertionOp.getDestinationOperand())
295  return failure();
296  // There must be a single use-def chain from the region iter_arg to the
297  // terminator. I.e., only one insertion op. Branches are not supported.
298  if (nextValue)
299  return failure();
300  nextValue = insertionOp.getUpdatedDestination();
301  }
302  }
303 
304  // Nothing can be hoisted if the chain does not continue with loop yielding
305  // op or a subset insertion op.
306  if (!nextValue)
307  return failure();
308  value = nextValue;
309  }
310 
311  // Hoist only if the SSA use-def chain ends in the yielding terminator of the
312  // loop and the yielded value is the `idx`-th operand. (I.e., there is no
313  // swapping yield.)
314  if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
315  return failure();
316 
317  return success();
318 }
319 
320 /// Hoist all subset ops that operate on the idx-th region iter_arg of the given
321 /// loop-like op and index into loop-invariant subset locations. Return the
322 /// newly created loop op (that has extra iter_args) or the original loop op if
323 /// nothing was hoisted.
324 static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
325  LoopLikeOpInterface loopLike,
326  BlockArgument iterArg) {
327  assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
328  BlockArgument *it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
329  int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
330  MatchingSubsets subsets;
331  if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
332  return loopLike;
333 
334  // Hoist all matching extraction-insertion pairs one-by-one.
335  for (auto it : subsets.getHoistableSubsetOps()) {
336  auto extractionOp = std::get<0>(it);
337  auto insertionOp = std::get<1>(it);
338 
339  // Ops cannot be hoisted if they depend on loop-variant values.
340  if (extractionOp) {
341  if (!canBeHoisted(extractionOp, [&](OpOperand &operand) {
342  return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
343  &operand == &extractionOp.getSourceOperand();
344  }))
345  extractionOp = {};
346  }
347  if (insertionOp) {
348  if (!canBeHoisted(insertionOp, [&](OpOperand &operand) {
349  return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
350  &operand == &insertionOp.getSourceOperand() ||
351  &operand == &insertionOp.getDestinationOperand();
352  }))
353  insertionOp = {};
354  }
355 
356  // Only hoist extraction-insertion pairs for now. Standalone extractions/
357  // insertions that are loop-invariant could be hoisted, but there may be
358  // easier ways to canonicalize the IR.
359  if (extractionOp && insertionOp) {
360  // Create a new loop with an additional iter_arg.
361  NewYieldValuesFn newYieldValuesFn =
362  [&](OpBuilder &b, Location loc,
363  ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
364  return {insertionOp.getSourceOperand().get()};
365  };
366  FailureOr<LoopLikeOpInterface> newLoop =
367  loopLike.replaceWithAdditionalYields(
368  rewriter, extractionOp.getResult(),
369  /*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn);
370  if (failed(newLoop))
371  return loopLike;
372  loopLike = *newLoop;
373 
374  // Hoist the extraction/insertion ops.
375  iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
376  OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
377  OpResult newLoopResult = loopLike.getLoopResults()->back();
378  rewriter.moveOpBefore(extractionOp, loopLike);
379  rewriter.moveOpAfter(insertionOp, loopLike);
380  rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
381  insertionOp.getDestinationOperand().get());
382  extractionOp.getSourceOperand().set(
383  loopLike.getTiedLoopInit(iterArg)->get());
384  rewriter.replaceAllUsesWith(loopResult,
385  insertionOp.getUpdatedDestination());
386  insertionOp.getSourceOperand().set(newLoopResult);
387  insertionOp.getDestinationOperand().set(loopResult);
388  }
389  }
390 
391  return loopLike;
392 }
393 
394 LoopLikeOpInterface
396  LoopLikeOpInterface loopLike) {
397  // Note: As subset ops are getting hoisted, the number of region iter_args
398  // increases. This can enable further hoisting opportunities on the new
399  // iter_args.
400  for (int64_t i = 0;
401  i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) {
402  loopLike = hoistSubsetAtIterArg(rewriter, loopLike,
403  loopLike.getRegionIterArgs()[i]);
404  }
405  return loopLike;
406 }
static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter, LoopLikeOpInterface loopLike, BlockArgument iterArg)
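Hoist all subset ops that operate on the idx-th region iter_arg of the given loop-like op and index into loop-invariant subset locations. Return the newly created loop op (that has extra iter_args) or the original loop op if nothing was hoisted.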
static OpOperand * getSingleTerminatorUse(Value value)
If the given value has a single use by an op that is a terminator, return that use.
static bool canBeHoisted(Operation *op, function_ref< bool(OpOperand &)> condition)
Checks whether the given op can be hoisted by checking that neither the op nor any of its contained operations depends on values defined inside of the loop.
This class represents an argument of a Block.
Definition: Value.h:309
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:318
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:207
This class represents an operand of an operation.
Definition: Value.h:257
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition: Value.h:457
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:773
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition: Operation.h:1111
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:368
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:646
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:197
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Include the generated interface declarations.
LoopLikeOpInterface hoistLoopInvariantSubsets(RewriterBase &rewriter, LoopLikeOpInterface loopLike)
Hoist loop-invariant tensor subsets (subset extraction and subset insertion ops) from loop-like ops.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.