26#include "llvm/Support/Debug.h"
30#define DEBUG_TYPE "linalg-hoisting"
32#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
47 Value newYieldValue) {
50 auto inits = llvm::to_vector(loop.getInits());
53 assert(
index < inits.size());
54 inits[
index] = newInitOperand;
56 scf::ForOp newLoop = scf::ForOp::create(
57 rewriter, loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(),
59 loop.getUnsignedCmp());
62 auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
63 yieldOp.setOperand(
index, newYieldValue);
66 rewriter.
mergeBlocks(loop.getBody(), newLoop.getBody(),
67 newLoop.getBody()->getArguments());
70 rewriter.
replaceOp(loop.getOperation(), newLoop->getResults());
99 root->
walk([&](vector::ExtractOp extractOp) {
100 LLVM_DEBUG(
DBGS() <<
"Candidate for hoisting: "
101 << *extractOp.getOperation() <<
"\n");
103 auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
108 auto blockArg = dyn_cast<BlockArgument>(extractOp.getSource());
113 OpOperand *initArg = loop.getTiedLoopInit(blockArg);
119 if (!blockArg.hasOneUse())
122 unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars();
126 loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp();
127 auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
131 LLVM_DEBUG(
DBGS() <<
"Candidate broadcast: " <<
broadcast <<
"\n");
134 if (broadcastInputType != extractOp.getType())
139 for (
auto operand : extractOp.getDynamicPosition())
140 if (!loop.isDefinedOutsideOfLoop(operand))
144 extractOp.getSourceMutable().assign(initArg->
get());
146 loop.moveOutOfLoop(extractOp);
150 rewriter, loop, extractOp.getResult(),
index,
broadcast.getSource());
152 LLVM_DEBUG(
DBGS() <<
"New loop: " << newLoop <<
"\n");
165 LoopLikeOpInterface loop) {
166 Value source = transferRead.getBase();
169 while (
auto viewLike = source.
getDefiningOp<ViewLikeOpInterface>()) {
170 if (viewLike.getViewDest() != source) {
173 source = viewLike.getViewSource();
178 llvm::SmallDenseSet<Operation *, 32> processed;
179 while (!users.empty()) {
182 if (!processed.insert(user).second)
184 if (
auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
185 Value viewDest = viewLike.getViewDest();
191 if (!loop->isAncestor(user))
199 bool verifyNonZeroTrip) {
213 if (verifyNonZeroTrip) {
214 root->
walk([&](LoopLikeOpInterface loopLike) {
215 std::optional<SmallVector<OpFoldResult>> lbs =
216 loopLike.getLoopLowerBounds();
217 std::optional<SmallVector<OpFoldResult>> ubs =
218 loopLike.getLoopUpperBounds();
226 for (
auto [lb,
ub] : llvm::zip_equal(lbs.value(), ubs.value())) {
227 FailureOr<int64_t> maxLb =
234 FailureOr<int64_t> minUb =
239 if (minUb.value() <= maxLb.value())
241 definiteNonZeroTripCountLoops.insert(loopLike);
246 root->
walk([&](vector::TransferReadOp transferRead) {
247 if (!isa<MemRefType>(transferRead.getShapedType()))
250 LLVM_DEBUG(
DBGS() <<
"Candidate for hoisting: "
251 << *transferRead.getOperation() <<
"\n");
252 auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp());
253 LLVM_DEBUG(
DBGS() <<
"Parent op: " << *transferRead->getParentOp()
255 if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
258 if (verifyNonZeroTrip && !definiteNonZeroTripCountLoops.contains(loop)) {
259 LLVM_DEBUG(
DBGS() <<
"Loop may have zero trip count: " << *loop
264 LLVM_DEBUG(
DBGS() <<
"Candidate read: " << *transferRead.getOperation()
272 vector::TransferWriteOp transferWrite;
273 for (
auto *sliceOp : llvm::reverse(forwardSlice)) {
274 auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
275 if (!candidateWrite ||
276 candidateWrite.getBase() != transferRead.getBase())
278 transferWrite = candidateWrite;
282 for (
auto operand : transferRead.getOperands())
283 if (!loop.isDefinedOutsideOfLoop(operand))
288 if (!transferWrite) {
292 loop.moveOutOfLoop(transferRead);
296 LLVM_DEBUG(
DBGS() <<
"Candidate: " << *transferWrite.getOperation()
309 if (transferRead.getIndices() != transferWrite.getIndices() ||
310 transferRead.getVectorType() != transferWrite.getVectorType() ||
311 transferRead.getPermutationMap() != transferWrite.getPermutationMap())
316 auto base = transferRead.getBase();
317 auto *source = base.getDefiningOp();
330 if (
auto assume = dyn_cast<memref::AssumeAlignmentOp>(source)) {
331 Value memPreAlignment = assume.getMemref();
333 llvm::count_if(base.getUses(), [&loop](
OpOperand &use) {
334 return loop->isAncestor(use.getOwner());
337 if (numInLoopUses && memPreAlignment.
hasOneUse())
340 if (isa_and_nonnull<ViewLikeOpInterface>(source))
344 if (llvm::any_of(base.getUsers(), llvm::IsaPred<ViewLikeOpInterface>))
353 for (
auto &use : transferRead.getBase().getUses()) {
354 if (!loop->isAncestor(use.getOwner()))
356 if (use.getOwner() == transferRead.getOperation() ||
357 use.getOwner() == transferWrite.getOperation())
359 if (
auto transferWriteUse =
360 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
362 cast<VectorTransferOpInterface>(*transferWrite),
363 cast<VectorTransferOpInterface>(*transferWriteUse),
366 }
else if (
auto transferReadUse =
367 dyn_cast<vector::TransferReadOp>(use.getOwner())) {
369 cast<VectorTransferOpInterface>(*transferWrite),
370 cast<VectorTransferOpInterface>(*transferReadUse),
381 loop.moveOutOfLoop(transferRead);
384 transferWrite->moveAfter(loop);
388 IRRewriter rewriter(transferRead.getContext());
394 auto maybeNewLoop = loop.replaceWithAdditionalYields(
395 rewriter, transferRead.getVector(),
397 if (failed(maybeNewLoop))
400 transferWrite.getValueToStoreMutable().assign(
401 maybeNewLoop->getOperation()->getResults().back());
static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, LoopLikeOpInterface loop)
static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop, Value newInitOperand, unsigned index, Value newYieldValue)
Replace loop with a new loop that has a different init operand at position index.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
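Return true if operation A properly dominates operation B, i.e. if A and B are in the same block and A properly dominates B within the block, or if the block that contains A properly dominates the block that contains B.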
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the same or another block in the same function.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, const Variable &var, const StopConditionFn &stopCondition=nullptr, ValueBoundsOptions options={})
Compute a constant bound for the given variable.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
user_range getUsers() const
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operations to access the same tensor/memref.
Include the generated interface declarations.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
llvm::SetVector< T, Vector, Set, N > SetVector
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. all the transitive uses of op), without including that operation.
Options that control value bound computation.