33 #include "llvm/ADT/StringRef.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Debug.h"
39 #define DEBUG_TYPE "linalg-hoisting"
41 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
56 Value newYieldValue) {
59 auto inits = llvm::to_vector(loop.getInits());
62 assert(index < inits.size());
63 inits[index] = newInitOperand;
65 scf::ForOp newLoop = rewriter.
create<scf::ForOp>(
66 loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
70 auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
71 yieldOp.setOperand(index, newYieldValue);
74 rewriter.
mergeBlocks(loop.getBody(), newLoop.getBody(),
75 newLoop.getBody()->getArguments());
78 rewriter.
replaceOp(loop.getOperation(), newLoop->getResults());
107 root->
walk([&](vector::ExtractOp extractOp) {
108 LLVM_DEBUG(
DBGS() <<
"Candidate for hoisting: "
109 << *extractOp.getOperation() <<
"\n");
111 auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
116 auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
121 OpOperand *initArg = loop.getTiedLoopInit(blockArg);
127 if (!blockArg.hasOneUse())
130 unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars();
134 loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp();
135 auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
139 LLVM_DEBUG(
DBGS() <<
"Candidate broadcast: " <<
broadcast <<
"\n");
142 if (broadcastInputType != extractOp.getType())
147 for (
auto operand : extractOp.getDynamicPosition())
148 if (!loop.isDefinedOutsideOfLoop(operand))
152 extractOp.getVectorMutable().assign(initArg->
get());
154 loop.moveOutOfLoop(extractOp);
158 rewriter, loop, extractOp.getResult(), index,
broadcast.getSource());
160 LLVM_DEBUG(
DBGS() <<
"New loop: " << newLoop <<
"\n");
173 LoopLikeOpInterface loop) {
174 Value source = transferRead.getSource();
178 dyn_cast_or_null<ViewLikeOpInterface>(source.
getDefiningOp()))
179 source = srcOp.getViewSource();
183 llvm::SmallDenseSet<Operation *, 32> processed;
184 while (!users.empty()) {
187 if (!processed.insert(user).second)
189 if (
auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
190 users.append(viewLike->getUsers().begin(), viewLike->getUsers().end());
195 if (!loop->isAncestor(user))
211 root->
walk([&](vector::TransferReadOp transferRead) {
212 if (!isa<MemRefType>(transferRead.getShapedType()))
215 LLVM_DEBUG(
DBGS() <<
"Candidate for hoisting: "
216 << *transferRead.getOperation() <<
"\n");
217 auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp());
218 LLVM_DEBUG(
DBGS() <<
"Parent op: " << *transferRead->getParentOp()
220 if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
223 LLVM_DEBUG(
DBGS() <<
"Candidate read: " << *transferRead.getOperation()
231 vector::TransferWriteOp transferWrite;
232 for (
auto *sliceOp : llvm::reverse(forwardSlice)) {
233 auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
234 if (!candidateWrite ||
235 candidateWrite.getSource() != transferRead.getSource())
237 transferWrite = candidateWrite;
241 for (
auto operand : transferRead.getOperands())
242 if (!loop.isDefinedOutsideOfLoop(operand))
247 if (!transferWrite) {
251 loop.moveOutOfLoop(transferRead);
255 LLVM_DEBUG(
DBGS() <<
"Candidate: " << *transferWrite.getOperation()
266 if (transferRead.getIndices() != transferWrite.getIndices() ||
267 transferRead.getVectorType() != transferWrite.getVectorType() ||
268 transferRead.getPermutationMap() != transferWrite.getPermutationMap())
271 auto *source = transferRead.getSource().getDefiningOp();
272 if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
275 source = transferWrite.getSource().getDefiningOp();
276 if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
284 for (
auto &use : transferRead.getSource().getUses()) {
285 if (!loop->isAncestor(use.getOwner()))
287 if (use.getOwner() == transferRead.getOperation() ||
288 use.getOwner() == transferWrite.getOperation())
290 if (
auto transferWriteUse =
291 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
293 cast<VectorTransferOpInterface>(*transferWrite),
294 cast<VectorTransferOpInterface>(*transferWriteUse),
297 }
else if (
auto transferReadUse =
298 dyn_cast<vector::TransferReadOp>(use.getOwner())) {
300 cast<VectorTransferOpInterface>(*transferWrite),
301 cast<VectorTransferOpInterface>(*transferReadUse),
312 loop.moveOutOfLoop(transferRead);
315 transferWrite->moveAfter(loop);
318 IRRewriter rewriter(transferRead.getContext());
324 auto maybeNewLoop = loop.replaceWithAdditionalYields(
325 rewriter, transferRead.getVector(),
327 if (failed(maybeNewLoop))
330 transferWrite.getVectorMutable().assign(
331 maybeNewLoop->getOperation()->getResults().back());
static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, LoopLikeOpInterface loop)
static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp loop, Value newInitOperand, unsigned index, Value newYieldValue)
Replace loop with a new loop that has a different init operand at position index.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
user_range getUsers() const
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
void hoistRedundantVectorTransfers(Operation *root)
Hoist vector.transfer_read/vector.transfer_write pairs on buffers out of the immediately enclosing scf::ForOp...
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. all the transitive uses of op).