MLIR  20.0.0git
LoopInvariantCodeMotionUtils.cpp
Go to the documentation of this file.
1 //===- LoopInvariantCodeMotionUtils.cpp - LICM Utils ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the core LICM algorithm.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/PatternMatch.h"
20 #include "llvm/Support/Debug.h"
21 #include <queue>
22 
23 #define DEBUG_TYPE "licm"
24 
25 using namespace mlir;
26 
27 /// Checks whether the given op can be hoisted by checking that
28 /// - the op and none of its contained operations depend on values inside of the
29 /// loop (by means of calling definedOutside).
30 /// - the op has no side-effects.
31 static bool canBeHoisted(Operation *op,
32  function_ref<bool(OpOperand &)> condition) {
33  // Do not move terminators.
35  return false;
36 
37  // Walk the nested operations and check that all used values are either
38  // defined outside of the loop or in a nested region, but not at the level of
39  // the loop body.
40  auto walkFn = [&](Operation *child) {
41  for (OpOperand &operand : child->getOpOperands()) {
42  // Ignore values defined in a nested region.
43  if (op->isAncestor(operand.get().getParentRegion()->getParentOp()))
44  continue;
45  if (!condition(operand))
46  return WalkResult::interrupt();
47  }
48  return WalkResult::advance();
49  };
50  return !op->walk(walkFn).wasInterrupted();
51 }
52 
53 static bool canBeHoisted(Operation *op,
54  function_ref<bool(Value)> definedOutside) {
55  return canBeHoisted(
56  op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
57 }
58 
60  ArrayRef<Region *> regions,
61  function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
62  function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
63  function_ref<void(Operation *, Region *)> moveOutOfRegion) {
64  size_t numMoved = 0;
65 
66  for (Region *region : regions) {
67  LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
68  << *region->getParentOp() << "\n");
69 
70  std::queue<Operation *> worklist;
71  // Add top-level operations in the loop body to the worklist.
72  for (Operation &op : region->getOps())
73  worklist.push(&op);
74 
75  auto definedOutside = [&](Value value) {
76  return isDefinedOutsideRegion(value, region);
77  };
78 
79  while (!worklist.empty()) {
80  Operation *op = worklist.front();
81  worklist.pop();
82  // Skip ops that have already been moved. Check if the op can be hoisted.
83  if (op->getParentRegion() != region)
84  continue;
85 
86  LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
87  if (!shouldMoveOutOfRegion(op, region) ||
88  !canBeHoisted(op, definedOutside))
89  continue;
90 
91  LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
92  moveOutOfRegion(op, region);
93  ++numMoved;
94 
95  // Since the op has been moved, we need to check its users within the
96  // top-level of the loop body.
97  for (Operation *user : op->getUsers())
98  if (user->getParentRegion() == region)
99  worklist.push(user);
100  }
101  }
102 
103  return numMoved;
104 }
105 
106 size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
107  return moveLoopInvariantCode(
108  loopLike.getLoopRegions(),
109  [&](Value value, Region *) {
110  return loopLike.isDefinedOutsideOfLoop(value);
111  },
112  [&](Operation *op, Region *) {
113  return isMemoryEffectFree(op) && isSpeculatable(op);
114  },
115  [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
116 }
117 
118 namespace {
119 /// Helper data structure that keeps track of equivalent/disjoint subset ops.
120 class MatchingSubsets {
121 public:
122  /// Insert a subset op.
123  void insert(SubsetOpInterface op, bool collectHoistableOps = true) {
124  allSubsetOps.push_back(op);
125  if (!collectHoistableOps)
126  return;
127  if (auto extractionOp =
128  dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
129  insertExtractionOp(extractionOp);
130  if (auto insertionOp =
131  dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
132  insertInsertionOp(insertionOp);
133  }
134 
135  /// Return a range of matching extraction-insertion subset ops. If there is no
136  /// matching extraction/insertion op, the respective value is empty. Ops are
137  /// skipped if there are other subset ops that are not guaranteed to operate
138  /// on disjoint subsets.
139  auto getHoistableSubsetOps() {
140  return llvm::make_filter_range(
141  llvm::zip(extractions, insertions), [&](auto pair) {
142  auto [extractionOp, insertionOp] = pair;
143  // Hoist only if the extracted and inserted values have the same type.
144  if (extractionOp && insertionOp &&
145  extractionOp->getResult(0).getType() !=
146  insertionOp.getSourceOperand().get().getType())
147  return false;
148  // Hoist only if there are no conflicting subset ops.
149  return allDisjoint(extractionOp, insertionOp);
150  });
151  }
152 
153  /// Populate subset ops starting from the given region iter_arg. Return
154  /// "failure" if non-subset ops are found along the path to the loop yielding
155  /// op or if there is no single path to the tied yielded operand. If
156  /// `collectHoistableOps` is set to "false", subset ops are gathered
157  /// throughout the traversal, but not enumerated by `getHoistableSubsetOps`.
158  LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
159  BlockArgument iterArg,
160  bool collectHoistableOps = true);
161 
162 private:
163  /// Helper function for equivalence of tensor values. Since only insertion
164  /// subset ops (that are also destination style ops) are followed when
165  /// traversing the SSA use-def chain, all tensor values are equivalent.
166  static bool isEquivalent(Value v1, Value v2) { return true; }
167 
168  /// Return "true" if the subsets of the given extraction and insertion ops
169  /// are operating disjoint from the subsets that all other known subset ops
170  /// are operating on.
171  bool allDisjoint(SubsetExtractionOpInterface extractionOp,
172  SubsetInsertionOpInterface insertionOp) const {
173  for (SubsetOpInterface other : allSubsetOps) {
174  if (other == extractionOp || other == insertionOp)
175  continue;
176  if (extractionOp &&
177  !other.operatesOnDisjointSubset(extractionOp, isEquivalent))
178  return false;
179  if (insertionOp &&
180  !other.operatesOnDisjointSubset(insertionOp, isEquivalent))
181  return false;
182  }
183  return true;
184  }
185 
186  /// Insert a subset extraction op. If the subset is equivalent to an existing
187  /// subset insertion op, pair them up. (If there is already a paired up subset
188  /// extraction op, overwrite the subset extraction op.)
189  void insertExtractionOp(SubsetExtractionOpInterface extractionOp) {
190  for (auto it : llvm::enumerate(insertions)) {
191  if (!it.value())
192  continue;
193  auto other = cast<SubsetOpInterface>(it.value().getOperation());
194  if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) {
195  extractions[it.index()] = extractionOp;
196  return;
197  }
198  }
199  // There is no known equivalent insertion op. Create a new entry.
200  extractions.push_back(extractionOp);
201  insertions.push_back({});
202  }
203 
204  /// Insert a subset insertion op. If the subset is equivalent to an existing
205  /// subset extraction op, pair them up. (If there is already a paired up
206  /// subset insertion op, overwrite the subset insertion op.)
207  void insertInsertionOp(SubsetInsertionOpInterface insertionOp) {
208  for (auto it : llvm::enumerate(extractions)) {
209  if (!it.value())
210  continue;
211  auto other = cast<SubsetOpInterface>(it.value().getOperation());
212  if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) {
213  insertions[it.index()] = insertionOp;
214  return;
215  }
216  }
217  // There is no known equivalent extraction op. Create a new entry.
218  extractions.push_back({});
219  insertions.push_back(insertionOp);
220  }
221 
224  SmallVector<SubsetOpInterface> allSubsetOps;
225 };
226 } // namespace
227 
228 /// If the given value has a single use by an op that is a terminator, return
229 /// that use. Otherwise, return nullptr.
231  if (!value.hasOneUse())
232  return nullptr;
233  OpOperand &use = *value.getUses().begin();
235  return &use;
236  return nullptr;
237 }
238 
239 LogicalResult
240 MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
241  BlockArgument iterArg,
242  bool collectHoistableOps) {
243  assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
244  Value value = iterArg;
245 
246  // Traverse use-def chain. Subset ops can be hoisted only if all ops along the
247  // use-def chain starting from the region iter_arg are subset extraction or
248  // subset insertion ops. The chain must terminate at the corresponding yield
249  // operand (e.g., no swapping of iter_args).
250  OpOperand *yieldedOperand = nullptr;
251  // Iterate until the single use of the current SSA value is a terminator,
252  // which is expected to be the yielding operation of the loop.
253  while (!(yieldedOperand = getSingleTerminatorUse(value))) {
254  Value nextValue = {};
255 
256  for (OpOperand &use : value.getUses()) {
257  if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
258  // Subset ops in nested loops are collected to check if there are only
259  // disjoint subset ops, but such subset ops are not subject to hoisting.
260  // To hoist subset ops from nested loops, the hoisting transformation
261  // should be run on the nested loop.
262  auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use);
263  if (!nestedIterArg)
264  return failure();
265  // Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA
266  // use-def chain starting at `nestedIterArg` and terminating in the
267  // tied, yielding operand.
268  if (failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg,
269  /*collectHoistableOps=*/false)))
270  return failure();
271  nextValue = nestedLoop.getTiedLoopResult(&use);
272  continue;
273  }
274 
275  auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
276  if (!subsetOp)
277  return failure();
278  insert(subsetOp);
279 
280  if (auto insertionOp =
281  dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
282  // Current implementation expects that the insertionOp implement
283  // the destinationStyleOpInterface as well. Abort if that tha is not
284  // the case
285  if (!isa<DestinationStyleOpInterface>(use.getOwner())) {
286  return failure();
287  }
288 
289  // The value must be used as a destination. (In case of a source, the
290  // entire tensor would be read, which would prevent any hoisting.)
291  if (&use != &insertionOp.getDestinationOperand())
292  return failure();
293  // There must be a single use-def chain from the region iter_arg to the
294  // terminator. I.e., only one insertion op. Branches are not supported.
295  if (nextValue)
296  return failure();
297  nextValue = insertionOp.getUpdatedDestination();
298  }
299  }
300 
301  // Nothing can be hoisted if the chain does not continue with loop yielding
302  // op or a subset insertion op.
303  if (!nextValue)
304  return failure();
305  value = nextValue;
306  }
307 
308  // Hoist only if the SSA use-def chain ends in the yielding terminator of the
309  // loop and the yielded value is the `idx`-th operand. (I.e., there is no
310  // swapping yield.)
311  if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
312  return failure();
313 
314  return success();
315 }
316 
317 /// Hoist all subset ops that operate on the idx-th region iter_arg of the given
318 /// loop-like op and index into loop-invariant subset locations. Return the
319 /// newly created loop op (that has extra iter_args) or the original loop op if
320 /// nothing was hoisted.
321 static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
322  LoopLikeOpInterface loopLike,
323  BlockArgument iterArg) {
324  assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
325  auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
326  int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
327  MatchingSubsets subsets;
328  if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
329  return loopLike;
330 
331  // Hoist all matching extraction-insertion pairs one-by-one.
332  for (auto it : subsets.getHoistableSubsetOps()) {
333  auto extractionOp = std::get<0>(it);
334  auto insertionOp = std::get<1>(it);
335 
336  // Ops cannot be hoisted if they depend on loop-variant values.
337  if (extractionOp) {
338  if (!canBeHoisted(extractionOp, [&](OpOperand &operand) {
339  return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
340  &operand == &extractionOp.getSourceOperand();
341  }))
342  extractionOp = {};
343  }
344  if (insertionOp) {
345  if (!canBeHoisted(insertionOp, [&](OpOperand &operand) {
346  return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
347  &operand == &insertionOp.getSourceOperand() ||
348  &operand == &insertionOp.getDestinationOperand();
349  }))
350  insertionOp = {};
351  }
352 
353  // Only hoist extraction-insertion pairs for now. Standalone extractions/
354  // insertions that are loop-invariant could be hoisted, but there may be
355  // easier ways to canonicalize the IR.
356  if (extractionOp && insertionOp) {
357  // Create a new loop with an additional iter_arg.
358  NewYieldValuesFn newYieldValuesFn =
359  [&](OpBuilder &b, Location loc,
360  ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
361  return {insertionOp.getSourceOperand().get()};
362  };
363  FailureOr<LoopLikeOpInterface> newLoop =
364  loopLike.replaceWithAdditionalYields(
365  rewriter, extractionOp.getResult(),
366  /*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn);
367  if (failed(newLoop))
368  return loopLike;
369  loopLike = *newLoop;
370 
371  // Hoist the extraction/insertion ops.
372  iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
373  OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
374  OpResult newLoopResult = loopLike.getLoopResults()->back();
375  rewriter.moveOpBefore(extractionOp, loopLike);
376  rewriter.moveOpAfter(insertionOp, loopLike);
377  rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
378  insertionOp.getDestinationOperand().get());
379  extractionOp.getSourceOperand().set(
380  loopLike.getTiedLoopInit(iterArg)->get());
381  rewriter.replaceAllUsesWith(loopResult,
382  insertionOp.getUpdatedDestination());
383  insertionOp.getSourceOperand().set(newLoopResult);
384  insertionOp.getDestinationOperand().set(loopResult);
385  }
386  }
387 
388  return loopLike;
389 }
390 
391 LoopLikeOpInterface
393  LoopLikeOpInterface loopLike) {
394  // Note: As subset ops are getting hoisted, the number of region iter_args
395  // increases. This can enable further hoisting opportunities on the new
396  // iter_args.
397  for (int64_t i = 0;
398  i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) {
399  loopLike = hoistSubsetAtIterArg(rewriter, loopLike,
400  loopLike.getRegionIterArgs()[i]);
401  }
402  return loopLike;
403 }
static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter, LoopLikeOpInterface loopLike, BlockArgument iterArg)
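Hoist all subset ops that operate on the idx-th region iter_arg of the given loop-like op and index into loop-invariant subset locations. Return the newly created loop op (that has extra iter_args) or the original loop op if nothing was hoisted.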
static OpOperand * getSingleTerminatorUse(Value value)
If the given value has a single use by an op that is a terminator, return that use.
static bool canBeHoisted(Operation *op, function_ref< bool(OpOperand &)> condition)
Checks whether the given op can be hoisted by checking that.
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:216
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:764
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:874
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
LoopLikeOpInterface hoistLoopInvariantSubsets(RewriterBase &rewriter, LoopLikeOpInterface loopLike)
Hoist loop-invariant tensor subsets (subset extraction and subset insertion ops) from loop-like ops.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.