33 #include "llvm/ADT/StringRef.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Debug.h"
39 #define DEBUG_TYPE "linalg-hoisting"
41 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
47 LoopLikeOpInterface loop) {
48 Value source = transferRead.getSource();
52 dyn_cast_or_null<ViewLikeOpInterface>(source.
getDefiningOp()))
53 source = srcOp.getViewSource();
57 llvm::SmallDenseSet<Operation *, 32> processed;
58 while (!users.empty()) {
61 if (!processed.insert(user).second)
63 if (
auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
64 users.append(viewLike->getUsers().begin(), viewLike->getUsers().end());
69 if (!loop->isAncestor(user))
85 func.walk([&](vector::TransferReadOp transferRead) {
86 if (!isa<MemRefType>(transferRead.getShapedType()))
89 LLVM_DEBUG(
DBGS() <<
"Candidate for hoisting: "
90 << *transferRead.getOperation() <<
"\n");
91 auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp());
92 LLVM_DEBUG(
DBGS() <<
"Parent op: " << *transferRead->getParentOp()
94 if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
97 LLVM_DEBUG(
DBGS() <<
"Candidate read: " << *transferRead.getOperation()
105 vector::TransferWriteOp transferWrite;
106 for (
auto *sliceOp : llvm::reverse(forwardSlice)) {
107 auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
108 if (!candidateWrite ||
109 candidateWrite.getSource() != transferRead.getSource())
111 transferWrite = candidateWrite;
115 for (
auto operand : transferRead.getOperands())
116 if (!loop.isDefinedOutsideOfLoop(operand))
121 if (!transferWrite) {
125 loop.moveOutOfLoop(transferRead);
129 LLVM_DEBUG(
DBGS() <<
"Candidate: " << *transferWrite.getOperation()
140 if (transferRead.getIndices() != transferWrite.getIndices() ||
141 transferRead.getVectorType() != transferWrite.getVectorType() ||
142 transferRead.getPermutationMap() != transferWrite.getPermutationMap())
145 auto *source = transferRead.getSource().getDefiningOp();
146 if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
149 source = transferWrite.getSource().getDefiningOp();
150 if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
158 for (
auto &use : transferRead.getSource().getUses()) {
159 if (!loop->isAncestor(use.getOwner()))
161 if (use.getOwner() == transferRead.getOperation() ||
162 use.getOwner() == transferWrite.getOperation())
164 if (
auto transferWriteUse =
165 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
167 cast<VectorTransferOpInterface>(*transferWrite),
168 cast<VectorTransferOpInterface>(*transferWriteUse),
171 }
else if (
auto transferReadUse =
172 dyn_cast<vector::TransferReadOp>(use.getOwner())) {
174 cast<VectorTransferOpInterface>(*transferWrite),
175 cast<VectorTransferOpInterface>(*transferReadUse),
186 loop.moveOutOfLoop(transferRead);
189 transferWrite->moveAfter(loop);
192 IRRewriter rewriter(transferRead.getContext());
198 auto maybeNewLoop = loop.replaceWithAdditionalYields(
199 rewriter, transferRead.getVector(),
204 transferWrite.getVectorMutable().assign(
205 maybeNewLoop->getOperation()->getResults().back());
static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, LoopLikeOpInterface loop)
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
void hoistRedundantVectorTransfers(func::FuncOp func)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
Include the generated interface declarations.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.