MLIR 23.0.0git
Hoisting.cpp
Go to the documentation of this file.
1//===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements functions concerned with hoisting invariant operations
10// in the context of Linalg transformations.
11//
12//===----------------------------------------------------------------------===//
13
24#include "mlir/IR/Dominance.h"
26#include "llvm/Support/Debug.h"
27
28using llvm::dbgs;
29
30#define DEBUG_TYPE "linalg-hoisting"
31
32#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
33
34using namespace mlir;
35using namespace mlir::linalg;
36
37/// Replace `loop` with a new loop that has a different init operand at
38/// position `index`. The body of this loop is moved over to the new loop.
39///
40/// `newInitOperands` specifies the replacement "init" operands.
41/// `newYieldValue` is the replacement yield value of the loop at position
42/// `index`.
43static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
44 scf::ForOp loop,
45 Value newInitOperand,
46 unsigned index,
47 Value newYieldValue) {
48 OpBuilder::InsertionGuard g(rewriter);
49 rewriter.setInsertionPoint(loop.getOperation());
50 auto inits = llvm::to_vector(loop.getInits());
51
52 // Replace the init value with the new operand.
53 assert(index < inits.size());
54 inits[index] = newInitOperand;
55
56 scf::ForOp newLoop = scf::ForOp::create(
57 rewriter, loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(),
58 loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {},
59 loop.getUnsignedCmp());
60
61 // Generate the new yield with the replaced operand.
62 auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
63 yieldOp.setOperand(index, newYieldValue);
64
65 // Move the loop body to the new op.
66 rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
67 newLoop.getBody()->getArguments());
68
69 // Replace the old loop.
70 rewriter.replaceOp(loop.getOperation(), newLoop->getResults());
71 return newLoop;
72}
73
74// Hoist out a pair of corresponding vector.extract+vector.broadcast
75// operations. This function transforms a loop like this:
76// %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) {
77// %e = vector.extract %iarg : t1 to t2
78// %u = "some_use"(%e) : (t2) -> t2
79// %b = vector.broadcast %u : t2 to t1
80// scf.yield %b : t1
81// }
82// into the following:
83// %e = vector.extract %v: t1 to t2
84// %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) {
85// %u' = "some_use"(%iarg) : (t2) -> t2
86// scf.yield %u' : t2
87// }
88// %res = vector.broadcast %res' : t2 to t1
90 Operation *root) {
91 bool changed = true;
92 while (changed) {
93 changed = false;
94 // First move loop invariant ops outside of their loop. This needs to be
95 // done before as we cannot move ops without interrupting the function walk.
96 root->walk(
97 [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
98
99 root->walk([&](vector::ExtractOp extractOp) {
100 LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
101 << *extractOp.getOperation() << "\n");
102
103 auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
104 if (!loop)
105 return WalkResult::advance();
106
107 // Check that the vector to extract from is a BlockArgument.
108 auto blockArg = dyn_cast<BlockArgument>(extractOp.getSource());
109 if (!blockArg)
110 return WalkResult::advance();
111
112 // Check that the blockArg is an iter_arg of the loop.
113 OpOperand *initArg = loop.getTiedLoopInit(blockArg);
114 if (!initArg)
115 return WalkResult::advance();
116
117 // If the iter_arg does not have only one use, it won't be possible to
118 // hoist the extractOp out.
119 if (!blockArg.hasOneUse())
120 return WalkResult::advance();
121
122 unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars();
123
124 // Check that the loop yields a broadcast that has just one use.
125 Operation *yieldedVal =
126 loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp();
127 auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
128 if (!broadcast || !broadcast.getResult().hasOneUse())
129 return WalkResult::advance();
130
131 LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
132
133 Type broadcastInputType = broadcast.getSourceType();
134 if (broadcastInputType != extractOp.getType())
135 return WalkResult::advance();
136
137 // The position of the extract must be defined outside of the loop if
138 // it is dynamic.
139 for (auto operand : extractOp.getDynamicPosition())
140 if (!loop.isDefinedOutsideOfLoop(operand))
141 return WalkResult::advance();
142
143 rewriter.modifyOpInPlace(broadcast, [&] {
144 extractOp.getSourceMutable().assign(initArg->get());
145 });
146 loop.moveOutOfLoop(extractOp);
147 rewriter.moveOpAfter(broadcast, loop);
148
149 scf::ForOp newLoop = replaceWithDifferentYield(
150 rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
151
152 LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
153
154 rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
155 rewriter.modifyOpInPlace(
156 broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); });
157
158 changed = true;
159 return WalkResult::interrupt();
160 });
161 }
162}
163
164static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
165 LoopLikeOpInterface loop) {
166 Value source = transferRead.getBase();
167
168 // Skip view-like Ops and retrive the actual soruce Operation
169 while (auto viewLike = source.getDefiningOp<ViewLikeOpInterface>()) {
170 if (viewLike.getViewDest() != source) {
171 break;
172 }
173 source = viewLike.getViewSource();
174 }
175
176 llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
177 source.getUsers().end());
178 llvm::SmallDenseSet<Operation *, 32> processed;
179 while (!users.empty()) {
180 Operation *user = users.pop_back_val();
181 // If the user has already been processed skip.
182 if (!processed.insert(user).second)
183 continue;
184 if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
185 Value viewDest = viewLike.getViewDest();
186 users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
187 continue;
188 }
189 if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
190 continue;
191 if (!loop->isAncestor(user))
192 continue;
193 return false;
194 }
195 return true;
196}
197
/// Iteratively hoists vector.transfer_read ops on memrefs, and matching
/// vector.transfer_write ops, out of their immediately enclosing scf.for or
/// affine.for.
///
/// A transfer_read is considered only when its parent op is such a loop and
/// all of its operands are defined outside the loop. If no transfer_write to
/// the same base is found in the read's forward slice, the read alone is moved
/// out of the loop when noAliasingUseInLoop() returns true. Otherwise the pair
/// is hoisted (read before the loop, write after it, with the read vector
/// carried through a new loop iter value that the write then stores) only if
/// checks 1-3 in the body pass. Loop-invariant code motion runs first, and the
/// whole process restarts after each pair is hoisted until nothing changes.
///
/// `root`: operation whose nested loops are processed.
/// `verifyNonZeroTrip`: when true, only loops whose trip count is proven
/// non-zero via ValueBounds are hoisted from.
///
/// NOTE(review): the embedded line numbering in this copy skips 198, 228-229,
/// 235-236, 361 and 368. The first line of the signature, the start of the
/// two constant-bound computations, and the openers of the two disjointness
/// checks appear to be missing and need restoring from upstream before this
/// compiles.
199 bool verifyNonZeroTrip) {
200 bool changed = true;
201 while (changed) {
202 changed = false;
203 // First move loop invariant ops outside of their loop. This needs to be
204 // done before as we cannot move ops without interrupting the function walk.
205 root->walk(
206 [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
207
208 // Find all loops that are certain to have non zero trip count. Any loops
209 // that are not part of this set cannot be hoisted from, since hoisting from
210 // a potentially zero trip count loop may cause a vector transfer to be
211 // executed when it shouldn't be.
212 llvm::DenseSet<LoopLikeOpInterface> definiteNonZeroTripCountLoops;
213 if (verifyNonZeroTrip) {
214 root->walk([&](LoopLikeOpInterface loopLike) {
215 std::optional<SmallVector<OpFoldResult>> lbs =
216 loopLike.getLoopLowerBounds();
217 std::optional<SmallVector<OpFoldResult>> ubs =
218 loopLike.getLoopUpperBounds();
219 // If loop bounds cannot be found, assume possibly zero trip count.
220 if (!lbs || !ubs)
221 return;
222
223 // Otherwise, use ValueBounds to find the maximum lower bound and
224 // minimum upper bound. If the bounds are found, and maxLb is less
225 // than the minUb, then the loop will not have zero trip count.
226 for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) {
227 FailureOr<int64_t> maxLb =
230 /*stopCondition=*/nullptr,
231 ValueBoundsOptions{/*closedUB=*/true});
232 if (failed(maxLb))
233 return;
234 FailureOr<int64_t> minUb =
237 if (failed(minUb))
238 return;
239 if (minUb.value() <= maxLb.value())
240 return;
// NOTE(review): this insert runs inside the per-dimension loop, so a
// multi-dimensional loop is recorded as soon as its first dimension passes,
// even if a later dimension then fails. Only scf.for / affine.for are
// hoisted from below, which presumably makes this harmless today; moving
// the insert after the loop would make the set accurate.
241 definiteNonZeroTripCountLoops.insert(loopLike);
242 }
243 });
244 }
245
// Visit every transfer_read. Hoisting a lone read continues the walk; hoisting
// a read/write pair sets `changed` and interrupts so the outer loop restarts.
246 root->walk([&](vector::TransferReadOp transferRead) {
247 if (!isa<MemRefType>(transferRead.getShapedType()))
248 return WalkResult::advance();
249
250 LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
251 << *transferRead.getOperation() << "\n");
252 auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp());
253 LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
254 << "\n");
255 if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
256 return WalkResult::advance();
257
258 if (verifyNonZeroTrip && !definiteNonZeroTripCountLoops.contains(loop)) {
259 LLVM_DEBUG(DBGS() << "Loop may have zero trip count: " << *loop
260 << "\n");
261 return WalkResult::advance();
262 }
263
264 LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
265 << "\n");
266
267 SetVector<Operation *> forwardSlice;
268 getForwardSlice(transferRead.getOperation(), &forwardSlice);
269
270 // Look for the last TransferWriteOp in the forwardSlice of
271 // `transferRead` that operates on the same memref.
// NOTE(review): the loop below walks `forwardSlice` back-to-front and
// overwrites `transferWrite` on every match, so it ends up holding the
// matching write that appears *earliest* in `forwardSlice`, not the last
// one. Confirm which of the two is intended.
272 vector::TransferWriteOp transferWrite;
273 for (auto *sliceOp : llvm::reverse(forwardSlice)) {
274 auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
275 if (!candidateWrite ||
276 candidateWrite.getBase() != transferRead.getBase())
277 continue;
278 transferWrite = candidateWrite;
279 }
280
281 // All operands of the TransferRead must be defined outside of the loop.
282 for (auto operand : transferRead.getOperands())
283 if (!loop.isDefinedOutsideOfLoop(operand))
284 return WalkResult::advance();
285
286 // Only hoist transfer_read / transfer_write pairs and singleton
287 // transfer_reads for now.
288 if (!transferWrite) {
289 // Make sure there are no other accesses to the memref before
290 // hoisting transfer_read.
291 if (noAliasingUseInLoop(transferRead, loop))
292 loop.moveOutOfLoop(transferRead);
293 return WalkResult::advance();
294 }
295
296 LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
297 << "\n");
298
299 // Approximate aliasing by checking that:
300 // 1. indices, vector type and permutation map are the same (i.e., the
301 // transfer_read/transfer_write ops are matching),
302 // 2. source operands for transfer.{read|write} do not originate from
303 // nor have users that are Ops implementing ViewLikeOpInterface.
304 // 3. no other operations in the loop access the same memref except
305 // for transfer_read/transfer_write accessing statically disjoint
306 // slices.
307
308 // Check 1.
309 if (transferRead.getIndices() != transferWrite.getIndices() ||
310 transferRead.getVectorType() != transferWrite.getVectorType() ||
311 transferRead.getPermutationMap() != transferWrite.getPermutationMap())
312 return WalkResult::advance();
313
314 // Check 2. Note, since both xfer Ops share the source, we only need to
315 // look at one of them.
316 auto base = transferRead.getBase();
317 auto *source = base.getDefiningOp();
318 if (source) {
319 // NOTE: We treat `memref.assume_alignment` as a special case.
320 //
321 // The idea is that it is safe to look past AssumeAlignmentOp (i.e.
322 // MemRef _before_ alignment) iff:
323 // 1. It has exactly two uses (these have to be the xfer Ops
324 // being looked at).
325 // 2. The original MemRef has only one use (i.e.
326 // AssumeAlignmentOp).
327 //
328 // Relaxing these conditions will most likely require proper alias
329 // analysis.
330 if (auto assume = dyn_cast<memref::AssumeAlignmentOp>(source)) {
331 Value memPreAlignment = assume.getMemref();
332 auto numInLoopUses =
333 llvm::count_if(base.getUses(), [&loop](OpOperand &use) {
334 return loop->isAncestor(use.getOwner());
335 });
336
// NOTE(review): condition 1 above says "exactly two uses", but this only
// checks that `base` has at least one use inside the loop (other in-loop
// uses are still screened by check 3 below).
337 if (numInLoopUses && memPreAlignment.hasOneUse())
338 source = memPreAlignment.getDefiningOp();
339 }
340 if (isa_and_nonnull<ViewLikeOpInterface>(source))
341 return WalkResult::advance();
342 }
343
344 if (llvm::any_of(base.getUsers(), llvm::IsaPred<ViewLikeOpInterface>))
345 return WalkResult::advance();
346
347 // Check 3.
348 // TODO: may want to memoize this information for performance but it
349 // likely gets invalidated often.
350 DominanceInfo dom(loop);
351 if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
352 return WalkResult::advance();
353 for (auto &use : transferRead.getBase().getUses()) {
354 if (!loop->isAncestor(use.getOwner()))
355 continue;
356 if (use.getOwner() == transferRead.getOperation() ||
357 use.getOwner() == transferWrite.getOperation())
358 continue;
359 if (auto transferWriteUse =
360 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
362 cast<VectorTransferOpInterface>(*transferWrite),
363 cast<VectorTransferOpInterface>(*transferWriteUse),
364 /*testDynamicValueUsingBounds=*/true))
365 return WalkResult::advance();
366 } else if (auto transferReadUse =
367 dyn_cast<vector::TransferReadOp>(use.getOwner())) {
369 cast<VectorTransferOpInterface>(*transferWrite),
370 cast<VectorTransferOpInterface>(*transferReadUse),
371 /*testDynamicValueUsingBounds=*/true))
372 return WalkResult::advance();
373 } else {
374 // Unknown use, we cannot prove that it doesn't alias with the
375 // transferRead/transferWrite operations.
376 return WalkResult::advance();
377 }
378 }
379
380 // Hoist read before.
381 loop.moveOutOfLoop(transferRead);
382
383 // Hoist write after.
384 transferWrite->moveAfter(loop);
385
386 // Rewrite `loop` with new yields by cloning and erase the original
387 // loop.
// The read's result becomes the new init operand (its in-loop uses are
// redirected to the new block argument) and the written vector is yielded.
388 IRRewriter rewriter(transferRead.getContext());
389 NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
390 ArrayRef<BlockArgument> newBBArgs) {
391 return SmallVector<Value>{transferWrite.getVector()};
392 };
393
394 auto maybeNewLoop = loop.replaceWithAdditionalYields(
395 rewriter, transferRead.getVector(),
396 /*replaceInitOperandUsesInLoop=*/true, yieldFn);
// NOTE(review): by the time this failure path is taken, the read and write
// have already been moved out of the loop, so the IR is left partially
// rewritten, and `changed` stays false so the outer loop stops.
397 if (failed(maybeNewLoop))
398 return WalkResult::interrupt();
399
// The hoisted write now stores the value produced by the new loop result.
400 transferWrite.getValueToStoreMutable().assign(
401 maybeNewLoop->getOperation()->getResults().back());
402 changed = true;
403 // Need to interrupt and restart because erasing the loop messes up
404 // the walk.
405 return WalkResult::interrupt();
406 });
407 }
408}
static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, LoopLikeOpInterface loop)
Definition Hoisting.cpp:164
static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop, Value newInitOperand, unsigned index, Value newYieldValue)
Replace loop with a new loop that has a different init operand at position index.
Definition Hoisting.cpp:43
#define DBGS()
Definition Hoisting.cpp:32
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
A class for computing basic dominance information.
Definition Dominance.h:143
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
This class represents an operand of an operation.
Definition Value.h:254
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:822
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, const Variable &var, const StopConditionFn &stopCondition=nullptr, ValueBoundsOptions options={})
Compute a constant bound for the given variable.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
user_range getUsers() const
Definition Value.h:218
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
Definition Hoisting.cpp:89
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
Definition Hoisting.cpp:198
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
Include the generated interface declarations.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
Options that control value bound computation.