MLIR 22.0.0git
Hoisting.cpp
Go to the documentation of this file.
1//===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements functions concerned with hoisting invariant operations
10// in the context of Linalg transformations.
11//
12//===----------------------------------------------------------------------===//
13
24#include "mlir/IR/Dominance.h"
26#include "llvm/Support/Debug.h"
27
28using llvm::dbgs;
29
30#define DEBUG_TYPE "linalg-hoisting"
31
32#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
33
34using namespace mlir;
35using namespace mlir::linalg;
36
37/// Replace `loop` with a new loop that has a different init operand at
38/// position `index`. The body of this loop is moved over to the new loop.
39///
40/// `newInitOperands` specifies the replacement "init" operands.
41/// `newYieldValue` is the replacement yield value of the loop at position
42/// `index`.
43static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
44 scf::ForOp loop,
45 Value newInitOperand,
46 unsigned index,
47 Value newYieldValue) {
48 OpBuilder::InsertionGuard g(rewriter);
49 rewriter.setInsertionPoint(loop.getOperation());
50 auto inits = llvm::to_vector(loop.getInits());
51
52 // Replace the init value with the new operand.
53 assert(index < inits.size());
54 inits[index] = newInitOperand;
55
56 scf::ForOp newLoop = scf::ForOp::create(
57 rewriter, loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(),
58 loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {},
59 loop.getUnsignedCmp());
60
61 // Generate the new yield with the replaced operand.
62 auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
63 yieldOp.setOperand(index, newYieldValue);
64
65 // Move the loop body to the new op.
66 rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
67 newLoop.getBody()->getArguments());
68
69 // Replace the old loop.
70 rewriter.replaceOp(loop.getOperation(), newLoop->getResults());
71 return newLoop;
72}
73
74// Hoist out a pair of corresponding vector.extract+vector.broadcast
75// operations. This function transforms a loop like this:
76// %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) {
77// %e = vector.extract %iarg : t1 to t2
78// %u = "some_use"(%e) : (t2) -> t2
79// %b = vector.broadcast %u : t2 to t1
80// scf.yield %b : t1
81// }
82// into the following:
83// %e = vector.extract %v: t1 to t2
84// %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) {
85// %u' = "some_use"(%iarg) : (t2) -> t2
86// scf.yield %u' : t2
87// }
88// %res = vector.broadcast %res' : t2 to t1
90 Operation *root) {
91 bool changed = true;
92 while (changed) {
93 changed = false;
94 // First move loop invariant ops outside of their loop. This needs to be
95 // done before as we cannot move ops without interrupting the function walk.
96 root->walk(
97 [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
98
99 root->walk([&](vector::ExtractOp extractOp) {
100 LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
101 << *extractOp.getOperation() << "\n");
102
103 auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
104 if (!loop)
105 return WalkResult::advance();
106
107 // Check that the vector to extract from is a BlockArgument.
108 auto blockArg = dyn_cast<BlockArgument>(extractOp.getSource());
109 if (!blockArg)
110 return WalkResult::advance();
111
112 // Check that the blockArg is an iter_arg of the loop.
113 OpOperand *initArg = loop.getTiedLoopInit(blockArg);
114 if (!initArg)
115 return WalkResult::advance();
116
117 // If the iter_arg does not have only one use, it won't be possible to
118 // hoist the extractOp out.
119 if (!blockArg.hasOneUse())
120 return WalkResult::advance();
121
122 unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars();
123
124 // Check that the loop yields a broadcast that has just one use.
125 Operation *yieldedVal =
126 loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp();
127 auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
128 if (!broadcast || !broadcast.getResult().hasOneUse())
129 return WalkResult::advance();
130
131 LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
132
133 Type broadcastInputType = broadcast.getSourceType();
134 if (broadcastInputType != extractOp.getType())
135 return WalkResult::advance();
136
137 // The position of the extract must be defined outside of the loop if
138 // it is dynamic.
139 for (auto operand : extractOp.getDynamicPosition())
140 if (!loop.isDefinedOutsideOfLoop(operand))
141 return WalkResult::advance();
142
143 rewriter.modifyOpInPlace(broadcast, [&] {
144 extractOp.getSourceMutable().assign(initArg->get());
145 });
146 loop.moveOutOfLoop(extractOp);
147 rewriter.moveOpAfter(broadcast, loop);
148
149 scf::ForOp newLoop = replaceWithDifferentYield(
150 rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
151
152 LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
153
154 rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
155 rewriter.modifyOpInPlace(
156 broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); });
157
158 changed = true;
159 return WalkResult::interrupt();
160 });
161 }
162}
163
164static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
165 LoopLikeOpInterface loop) {
166 Value source = transferRead.getBase();
167
168 // Skip view-like Ops and retrive the actual soruce Operation
169 while (auto viewLike = source.getDefiningOp<ViewLikeOpInterface>()) {
170 if (viewLike.getViewDest() != source) {
171 break;
172 }
173 source = viewLike.getViewSource();
174 }
175
176 llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
177 source.getUsers().end());
178 llvm::SmallDenseSet<Operation *, 32> processed;
179 while (!users.empty()) {
180 Operation *user = users.pop_back_val();
181 // If the user has already been processed skip.
182 if (!processed.insert(user).second)
183 continue;
184 if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
185 Value viewDest = viewLike.getViewDest();
186 users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
187 continue;
188 }
189 if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
190 continue;
191 if (!loop->isAncestor(user))
192 continue;
193 return false;
194 }
195 return true;
196}
197
199 bool verifyNonZeroTrip) {
// Hoists vector.transfer_read ops on memrefs out of their immediately
// enclosing scf.for / affine.for. When a matching vector.transfer_write on the
// same base is found in the read's forward slice, the read is moved before the
// loop, the write after it, and the loop is rebuilt to carry the vector as an
// extra iter_arg. A lone read is hoisted only if noAliasingUseInLoop() holds.
// Only pair hoists set `changed` and restart the walk; the outer `while` runs
// until a full walk performs no pair hoist.
// When `verifyNonZeroTrip` is set, only loops whose bounds prove a non-zero
// trip count are hoisted from.
200 bool changed = true;
201 while (changed) {
202 changed = false;
203 // First move loop invariant ops outside of their loop. This needs to be
204 // done before as we cannot move ops without interrupting the function walk.
205 root->walk(
206 [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
207
208 // Find all loops that are certain to have non zero trip count. Any loops
209 // that are not part of this set cannot be hoisted from, since hoisting from
210 // a potentially zero trip count loop may cause a vector transfer to be
211 // executed when it shouldn't be.
212 llvm::DenseSet<LoopLikeOpInterface> definiteNonZeroTripCountLoops;
213 if (verifyNonZeroTrip) {
214 root->walk([&](LoopLikeOpInterface loopLike) {
215 std::optional<SmallVector<OpFoldResult>> lbs =
216 loopLike.getLoopLowerBounds();
217 std::optional<SmallVector<OpFoldResult>> ubs =
218 loopLike.getLoopUpperBounds();
219 // If loop bounds cannot be found, assume possibly zero trip count.
220 if (!lbs || !ubs)
221 return;
222
223 // Otherwise, use ValueBounds to find the maximum lower bound and
224 // minimum upper bound. If the bounds are found, and maxLb is less
225 // than the minUb, then the loop will not have zero trip count.
226 for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) {
227 FailureOr<int64_t> maxLb =
230 /*stopCondition=*/nullptr, /*closedUB=*/true);
231 if (failed(maxLb))
232 return;
233 FailureOr<int64_t> minUb =
236 if (failed(minUb))
237 return;
238 if (minUb.value() <= maxLb.value())
239 return;
// NOTE(review): this insert runs inside the per-dimension loop, so a
// multi-dimensional loop-like op would be recorded after only its first
// (lb, ub) pair is checked. Only scf.for / affine.for are hoisted from
// below (presumably single-dimensional), so this looks harmless today.
// TODO: move the insert after the `for` loop.
240 definiteNonZeroTripCountLoops.insert(loopLike);
241 }
242 });
243 }
244
245 root->walk([&](vector::TransferReadOp transferRead) {
246 if (!isa<MemRefType>(transferRead.getShapedType()))
247 return WalkResult::advance();
248
249 LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
250 << *transferRead.getOperation() << "\n");
251 auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp());
252 LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
253 << "\n");
254 if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
255 return WalkResult::advance();
256
257 if (verifyNonZeroTrip && !definiteNonZeroTripCountLoops.contains(loop)) {
258 LLVM_DEBUG(DBGS() << "Loop may have zero trip count: " << *loop
259 << "\n");
260 return WalkResult::advance();
261 }
262
263 LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
264 << "\n");
265
266 SetVector<Operation *> forwardSlice;
267 getForwardSlice(transferRead.getOperation(), &forwardSlice);
268
269 // Look for the last TransferWriteOp in the forwardSlice of
270 // `transferRead` that operates on the same memref.
// NOTE(review): the loop below walks `forwardSlice` in reverse and never
// breaks, so the write that survives is the matching one that appears
// FIRST in `forwardSlice` order, not the last. Confirm which is intended.
271 vector::TransferWriteOp transferWrite;
272 for (auto *sliceOp : llvm::reverse(forwardSlice)) {
273 auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
274 if (!candidateWrite ||
275 candidateWrite.getBase() != transferRead.getBase())
276 continue;
277 transferWrite = candidateWrite;
278 }
279
280 // All operands of the TransferRead must be defined outside of the loop.
281 for (auto operand : transferRead.getOperands())
282 if (!loop.isDefinedOutsideOfLoop(operand))
283 return WalkResult::advance();
284
285 // Only hoist transfer_read / transfer_write pairs and singleton
286 // transfer_reads for now.
287 if (!transferWrite) {
288 // Make sure there are no other accesses to the memref before
289 // hoisting transfer_read.
290 if (noAliasingUseInLoop(transferRead, loop))
291 loop.moveOutOfLoop(transferRead);
292 return WalkResult::advance();
293 }
294
295 LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
296 << "\n");
297
298 // Approximate aliasing by checking that:
299 // 1. indices, vector type and permutation map are the same (i.e., the
300 // transfer_read/transfer_write ops are matching),
301 // 2. source operands for transfer.{read|write} do not originate from
302 // nor have users that are Ops implementing ViewLikeOpInterface.
303 // 3. no other operations in the loop access the same memref except
304 // for transfer_read/transfer_write accessing statically disjoint
305 // slices.
306
307 // Check 1.
308 if (transferRead.getIndices() != transferWrite.getIndices() ||
309 transferRead.getVectorType() != transferWrite.getVectorType() ||
310 transferRead.getPermutationMap() != transferWrite.getPermutationMap())
311 return WalkResult::advance();
312
313 // Check 2. Note, since both xfer Ops share the source, we only need to
314 // look at one of them.
315 auto base = transferRead.getBase();
316 auto *source = base.getDefiningOp();
317 if (source) {
318 // NOTE: We treat `memref.assume_alignment` as a special case.
319 //
320 // The idea is that it is safe to look past AssumeAlignmentOp (i.e.
321 // MemRef _before_ alignment) iff:
322 // 1. It has exactly two uses (these have to be the xfer Ops
323 // being looked at).
324 // 2. The original MemRef has only one use (i.e.
325 // AssumeAlignmentOp).
326 //
327 // Relaxing these conditions will most likely require proper alias
328 // analysis.
329 if (auto assume = dyn_cast<memref::AssumeAlignmentOp>(source)) {
330 Value memPreAlignment = assume.getMemref();
331 auto numInLoopUses =
332 llvm::count_if(base.getUses(), [&loop](OpOperand &use) {
333 return loop->isAncestor(use.getOwner());
334 });
335
// NOTE(review): this only requires at least one use of `base` nested in
// the loop, which is weaker than the "exactly two uses" condition stated
// above. Confirm which condition is intended.
336 if (numInLoopUses && memPreAlignment.hasOneUse())
337 source = memPreAlignment.getDefiningOp();
338 }
339 if (isa_and_nonnull<ViewLikeOpInterface>(source))
340 return WalkResult::advance();
341 }
342
343 if (llvm::any_of(base.getUsers(), llvm::IsaPred<ViewLikeOpInterface>))
344 return WalkResult::advance();
345
346 // Check 3.
347 // TODO: may want to memoize this information for performance but it
348 // likely gets invalidated often.
// The read must also properly dominate the write for the pair to be hoisted.
349 DominanceInfo dom(loop);
350 if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
351 return WalkResult::advance();
352 for (auto &use : transferRead.getBase().getUses()) {
353 if (!loop->isAncestor(use.getOwner()))
354 continue;
355 if (use.getOwner() == transferRead.getOperation() ||
356 use.getOwner() == transferWrite.getOperation())
357 continue;
358 if (auto transferWriteUse =
359 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
361 cast<VectorTransferOpInterface>(*transferWrite),
362 cast<VectorTransferOpInterface>(*transferWriteUse),
363 /*testDynamicValueUsingBounds=*/true))
364 return WalkResult::advance();
365 } else if (auto transferReadUse =
366 dyn_cast<vector::TransferReadOp>(use.getOwner())) {
368 cast<VectorTransferOpInterface>(*transferWrite),
369 cast<VectorTransferOpInterface>(*transferReadUse),
370 /*testDynamicValueUsingBounds=*/true))
371 return WalkResult::advance();
372 } else {
373 // Unknown use, we cannot prove that it doesn't alias with the
374 // transferRead/transferWrite operations.
375 return WalkResult::advance();
376 }
377 }
378
379 // Hoist read before.
380 loop.moveOutOfLoop(transferRead);
381
382 // Hoist write after.
383 transferWrite->moveAfter(loop);
384
385 // Rewrite `loop` with new yields by cloning and erase the original
386 // loop.
387 IRRewriter rewriter(transferRead.getContext());
388 NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
389 ArrayRef<BlockArgument> newBBArgs) {
390 return SmallVector<Value>{transferWrite.getVector()};
391 };
392
393 auto maybeNewLoop = loop.replaceWithAdditionalYields(
394 rewriter, transferRead.getVector(),
395 /*replaceInitOperandUsesInLoop=*/true, yieldFn);
// NOTE(review): by this point the read has already been moved before the
// loop and the write after it. On failure the walk stops without setting
// `changed`, so that partially rewritten IR is left in place.
396 if (failed(maybeNewLoop))
397 return WalkResult::interrupt();
398
399 transferWrite.getValueToStoreMutable().assign(
400 maybeNewLoop->getOperation()->getResults().back());
401 changed = true;
402 // Need to interrupt and restart because erasing the loop messes up
403 // the walk.
404 return WalkResult::interrupt();
405 });
406 }
407}
static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, LoopLikeOpInterface loop)
Definition Hoisting.cpp:164
static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop, Value newInitOperand, unsigned index, Value newYieldValue)
Replace loop with a new loop that has a different init operand at position index.
Definition Hoisting.cpp:43
#define DBGS()
Definition Hoisting.cpp:32
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
A class for computing basic dominance information.
Definition Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, const Variable &var, const StopConditionFn &stopCondition=nullptr, bool closedUB=false)
Compute a constant bound for the given variable.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
user_range getUsers() const
Definition Value.h:218
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
Definition Hoisting.cpp:89
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
Definition Hoisting.cpp:198
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.