MLIR 22.0.0git
LoopInvariantCodeMotionUtils.cpp
Go to the documentation of this file.
1//===- LoopInvariantCodeMotionUtils.cpp - LICM Utils ------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the core LICM algorithm.
10//
11//===----------------------------------------------------------------------===//
12
14
15#include "mlir/IR/Operation.h"
21#include "llvm/Support/Debug.h"
22#include "llvm/Support/DebugLog.h"
23#include <queue>
24
25#define DEBUG_TYPE "licm"
26
27using namespace mlir;
28
29/// Checks whether the given op can be hoisted by checking that
30/// - the op and none of its contained operations depend on values inside of the
31/// loop (by means of calling definedOutside).
32/// - the op has no side-effects.
33static bool canBeHoisted(Operation *op,
34 function_ref<bool(OpOperand &)> condition) {
35 // Do not move terminators.
37 return false;
38
39 // Walk the nested operations and check that all used values are either
40 // defined outside of the loop or in a nested region, but not at the level of
41 // the loop body.
42 auto walkFn = [&](Operation *child) {
43 for (OpOperand &operand : child->getOpOperands()) {
44 // Ignore values defined in a nested region.
45 if (op->isAncestor(operand.get().getParentRegion()->getParentOp()))
46 continue;
47 if (!condition(operand))
48 return WalkResult::interrupt();
49 }
50 return WalkResult::advance();
51 };
52 return !op->walk(walkFn).wasInterrupted();
53}
54
55static bool canBeHoisted(Operation *op,
56 function_ref<bool(Value)> definedOutside) {
57 return canBeHoisted(
58 op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
59}
60
62 ArrayRef<Region *> regions,
63 function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
64 function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
65 function_ref<void(Operation *, Region *)> moveOutOfRegion) {
66 size_t numMoved = 0;
67
68 for (Region *region : regions) {
69 LDBG() << "Original loop:\n"
70 << OpWithFlags(region->getParentOp(),
71 OpPrintingFlags().skipRegions());
72
73 std::queue<Operation *> worklist;
74 // Add top-level operations in the loop body to the worklist.
75 for (Operation &op : region->getOps())
76 worklist.push(&op);
77
78 auto definedOutside = [&](Value value) {
79 return isDefinedOutsideRegion(value, region);
80 };
81
82 while (!worklist.empty()) {
83 Operation *op = worklist.front();
84 worklist.pop();
85 // Skip ops that have already been moved. Check if the op can be hoisted.
86 if (op->getParentRegion() != region)
87 continue;
88
89 LDBG() << "Checking op: "
90 << OpWithFlags(op, OpPrintingFlags().skipRegions());
91 if (!shouldMoveOutOfRegion(op, region) ||
92 !canBeHoisted(op, definedOutside))
93 continue;
94
95 LDBG() << "Moving loop-invariant op: "
96 << OpWithFlags(op, OpPrintingFlags().skipRegions());
97 moveOutOfRegion(op, region);
98 ++numMoved;
99
100 // Since the op has been moved, we need to check its users within the
101 // top-level of the loop body.
102 for (Operation *user : op->getUsers())
103 if (user->getParentRegion() == region)
104 worklist.push(user);
105 }
106 }
107
108 return numMoved;
109}
110
111size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
113 loopLike.getLoopRegions(),
114 [&](Value value, Region *) {
115 return loopLike.isDefinedOutsideOfLoop(value);
116 },
117 [&](Operation *op, Region *) { return isPure(op); },
118 [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
119}
120
121namespace {
122/// Helper data structure that keeps track of equivalent/disjoint subset ops.
123class MatchingSubsets {
124public:
125 /// Insert a subset op.
126 void insert(SubsetOpInterface op, bool collectHoistableOps = true) {
127 allSubsetOps.push_back(op);
128 if (!collectHoistableOps)
129 return;
130 if (auto extractionOp =
131 dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
132 insertExtractionOp(extractionOp);
133 if (auto insertionOp =
134 dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
135 insertInsertionOp(insertionOp);
136 }
137
138 /// Return a range of matching extraction-insertion subset ops. If there is no
139 /// matching extraction/insertion op, the respective value is empty. Ops are
140 /// skipped if there are other subset ops that are not guaranteed to operate
141 /// on disjoint subsets.
142 auto getHoistableSubsetOps() {
143 return llvm::make_filter_range(
144 llvm::zip(extractions, insertions), [&](auto pair) {
145 auto [extractionOp, insertionOp] = pair;
146 // Hoist only if the extracted and inserted values have the same type.
147 if (extractionOp && insertionOp &&
148 extractionOp->getResult(0).getType() !=
149 insertionOp.getSourceOperand().get().getType())
150 return false;
151 // Hoist only if there are no conflicting subset ops.
152 return allDisjoint(extractionOp, insertionOp);
153 });
154 }
155
156 /// Populate subset ops starting from the given region iter_arg. Return
157 /// "failure" if non-subset ops are found along the path to the loop yielding
158 /// op or if there is no single path to the tied yielded operand. If
159 /// `collectHoistableOps` is set to "false", subset ops are gathered
160 /// throughout the traversal, but not enumerated by `getHoistableSubsetOps`.
161 LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
162 BlockArgument iterArg,
163 bool collectHoistableOps = true);
164
165private:
166 /// Helper function for equivalence of tensor values. Since only insertion
167 /// subset ops (that are also destination style ops) are followed when
168 /// traversing the SSA use-def chain, all tensor values are equivalent.
169 static bool isEquivalent(Value v1, Value v2) { return true; }
170
171 /// Return "true" if the subsets of the given extraction and insertion ops
172 /// are operating disjoint from the subsets that all other known subset ops
173 /// are operating on.
174 bool allDisjoint(SubsetExtractionOpInterface extractionOp,
175 SubsetInsertionOpInterface insertionOp) const {
176 for (SubsetOpInterface other : allSubsetOps) {
177 if (other == extractionOp || other == insertionOp)
178 continue;
179 if (extractionOp &&
180 !other.operatesOnDisjointSubset(extractionOp, isEquivalent))
181 return false;
182 if (insertionOp &&
183 !other.operatesOnDisjointSubset(insertionOp, isEquivalent))
184 return false;
185 }
186 return true;
187 }
188
189 /// Insert a subset extraction op. If the subset is equivalent to an existing
190 /// subset insertion op, pair them up. (If there is already a paired up subset
191 /// extraction op, overwrite the subset extraction op.)
192 void insertExtractionOp(SubsetExtractionOpInterface extractionOp) {
193 for (auto it : llvm::enumerate(insertions)) {
194 if (!it.value())
195 continue;
196 auto other = cast<SubsetOpInterface>(it.value().getOperation());
197 if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) {
198 extractions[it.index()] = extractionOp;
199 return;
200 }
201 }
202 // There is no known equivalent insertion op. Create a new entry.
203 extractions.push_back(extractionOp);
204 insertions.push_back({});
205 }
206
207 /// Insert a subset insertion op. If the subset is equivalent to an existing
208 /// subset extraction op, pair them up. (If there is already a paired up
209 /// subset insertion op, overwrite the subset insertion op.)
210 void insertInsertionOp(SubsetInsertionOpInterface insertionOp) {
211 for (auto it : llvm::enumerate(extractions)) {
212 if (!it.value())
213 continue;
214 auto other = cast<SubsetOpInterface>(it.value().getOperation());
215 if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) {
216 insertions[it.index()] = insertionOp;
217 return;
218 }
219 }
220 // There is no known equivalent extraction op. Create a new entry.
221 extractions.push_back({});
222 insertions.push_back(insertionOp);
223 }
224
225 SmallVector<SubsetExtractionOpInterface> extractions;
226 SmallVector<SubsetInsertionOpInterface> insertions;
227 SmallVector<SubsetOpInterface> allSubsetOps;
228};
229} // namespace
230
231/// If the given value has a single use by an op that is a terminator, return
232/// that use. Otherwise, return nullptr.
234 if (!value.hasOneUse())
235 return nullptr;
236 OpOperand &use = *value.getUses().begin();
238 return &use;
239 return nullptr;
240}
241
242LogicalResult
243MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
244 BlockArgument iterArg,
245 bool collectHoistableOps) {
246 assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
247 Value value = iterArg;
248
249 // Traverse use-def chain. Subset ops can be hoisted only if all ops along the
250 // use-def chain starting from the region iter_arg are subset extraction or
251 // subset insertion ops. The chain must terminate at the corresponding yield
252 // operand (e.g., no swapping of iter_args).
253 OpOperand *yieldedOperand = nullptr;
254 // Iterate until the single use of the current SSA value is a terminator,
255 // which is expected to be the yielding operation of the loop.
256 while (!(yieldedOperand = getSingleTerminatorUse(value))) {
257 Value nextValue = {};
258
259 for (OpOperand &use : value.getUses()) {
260 if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
261 // Subset ops in nested loops are collected to check if there are only
262 // disjoint subset ops, but such subset ops are not subject to hoisting.
263 // To hoist subset ops from nested loops, the hoisting transformation
264 // should be run on the nested loop.
265 auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use);
266 if (!nestedIterArg)
267 return failure();
268 // Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA
269 // use-def chain starting at `nestedIterArg` and terminating in the
270 // tied, yielding operand.
271 if (failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg,
272 /*collectHoistableOps=*/false)))
273 return failure();
274 nextValue = nestedLoop.getTiedLoopResult(&use);
275 continue;
276 }
277
278 auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
279 if (!subsetOp)
280 return failure();
281 insert(subsetOp);
282
283 if (auto insertionOp =
284 dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
285 // Current implementation expects that the insertionOp implement
286 // the DestinationStyleOpInterface and with pure tensor semantics
287 // as well. Abort if that is not the case.
288 auto dstOp = dyn_cast<DestinationStyleOpInterface>(use.getOwner());
289 if (!dstOp || !dstOp.hasPureTensorSemantics())
290 return failure();
291
292 // The value must be used as a destination. (In case of a source, the
293 // entire tensor would be read, which would prevent any hoisting.)
294 if (&use != &insertionOp.getDestinationOperand())
295 return failure();
296 // There must be a single use-def chain from the region iter_arg to the
297 // terminator. I.e., only one insertion op. Branches are not supported.
298 if (nextValue)
299 return failure();
300 nextValue = insertionOp.getUpdatedDestination();
301 }
302 }
303
304 // Nothing can be hoisted if the chain does not continue with loop yielding
305 // op or a subset insertion op.
306 if (!nextValue)
307 return failure();
308 value = nextValue;
309 }
310
311 // Hoist only if the SSA use-def chain ends in the yielding terminator of the
312 // loop and the yielded value is the `idx`-th operand. (I.e., there is no
313 // swapping yield.)
314 if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
315 return failure();
316
317 return success();
318}
319
320/// Hoist all subset ops that operate on the idx-th region iter_arg of the given
321/// loop-like op and index into loop-invariant subset locations. Return the
322/// newly created loop op (that has extra iter_args) or the original loop op if
323/// nothing was hoisted.
324static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
325 LoopLikeOpInterface loopLike,
326 BlockArgument iterArg) {
327 assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
328 BlockArgument *it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
329 int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
330 MatchingSubsets subsets;
331 if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
332 return loopLike;
333
334 // Hoist all matching extraction-insertion pairs one-by-one.
335 for (auto it : subsets.getHoistableSubsetOps()) {
336 auto extractionOp = std::get<0>(it);
337 auto insertionOp = std::get<1>(it);
338
339 // Ops cannot be hoisted if they depend on loop-variant values.
340 if (extractionOp) {
341 if (!canBeHoisted(extractionOp, [&](OpOperand &operand) {
342 return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
343 &operand == &extractionOp.getSourceOperand();
344 }))
345 extractionOp = {};
346 }
347 if (insertionOp) {
348 if (!canBeHoisted(insertionOp, [&](OpOperand &operand) {
349 return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
350 &operand == &insertionOp.getSourceOperand() ||
351 &operand == &insertionOp.getDestinationOperand();
352 }))
353 insertionOp = {};
354 }
355
356 // Only hoist extraction-insertion pairs for now. Standalone extractions/
357 // insertions that are loop-invariant could be hoisted, but there may be
358 // easier ways to canonicalize the IR.
359 if (extractionOp && insertionOp) {
360 // Create a new loop with an additional iter_arg.
361 NewYieldValuesFn newYieldValuesFn =
362 [&](OpBuilder &b, Location loc,
364 return {insertionOp.getSourceOperand().get()};
365 };
366 FailureOr<LoopLikeOpInterface> newLoop =
367 loopLike.replaceWithAdditionalYields(
368 rewriter, extractionOp.getResult(),
369 /*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn);
370 if (failed(newLoop))
371 return loopLike;
372 loopLike = *newLoop;
373
374 // Hoist the extraction/insertion ops.
375 iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
376 OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
377 OpResult newLoopResult = loopLike.getLoopResults()->back();
378 rewriter.moveOpBefore(extractionOp, loopLike);
379 rewriter.moveOpAfter(insertionOp, loopLike);
380 rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
381 insertionOp.getDestinationOperand().get());
382 extractionOp.getSourceOperand().set(
383 loopLike.getTiedLoopInit(iterArg)->get());
384 rewriter.replaceAllUsesWith(loopResult,
385 insertionOp.getUpdatedDestination());
386 insertionOp.getSourceOperand().set(newLoopResult);
387 insertionOp.getDestinationOperand().set(loopResult);
388 }
389 }
390
391 return loopLike;
392}
393
394LoopLikeOpInterface
396 LoopLikeOpInterface loopLike) {
397 // Note: As subset ops are getting hoisted, the number of region iter_args
398 // increases. This can enable further hoisting opportunities on the new
399 // iter_args.
400 for (int64_t i = 0;
401 i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) {
402 loopLike = hoistSubsetAtIterArg(rewriter, loopLike,
403 loopLike.getRegionIterArgs()[i]);
404 }
405 return loopLike;
406}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static OpOperand * getSingleTerminatorUse(Value value)
If the given value has a single use by an op that is a terminator, return that use.
static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter, LoopLikeOpInterface loopLike, BlockArgument iterArg)
Hoist all subset ops that operate on the idx-th region iter_arg of the given loop-like op and index i...
static bool canBeHoisted(Operation *op, function_ref< bool(OpOperand &)> condition)
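Checks whether the given op can be hoisted: it must not be a terminator, and every operand used by it (or by its nested ops) must satisfy the condition unless it is defined within the op itself.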
static bool canBeHoisted(Operation *op, function_ref< bool(Value)> definedOutside)
Helper to know if an op can be hoisted out of the region.
This class represents an argument of a Block.
Definition Value.h:309
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:318
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
This class represents an operand of an operation.
Definition Value.h:257
Set of flags used to control the behavior of the various IR print methods (e.g.
This is a value defined by a result of an operation.
Definition Value.h:457
This class provides the API for ops that are known to be terminators.
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition Operation.h:1111
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition Operation.h:263
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
LoopLikeOpInterface hoistLoopInvariantSubsets(RewriterBase &rewriter, LoopLikeOpInterface loopLike)
Hoist loop-invariant tensor subsets (subset extraction and subset insertion ops) from loop-like ops.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.