25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/ErrorHandling.h"
28 #define DEBUG_TYPE "subset-hoisting"
30 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
39 vector::TransferWriteOp transferWriteOp) {
40 for (
Value operand : transferWriteOp.getIndices())
41 if (!forOp.isDefinedOutsideOfLoop(operand))
49 tensor::InsertSliceOp insertSliceOp) {
50 for (
Value operand : insertSliceOp->getOperands().drop_front(
51 tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex()))
52 if (!forOp.isDefinedOutsideOfLoop(operand))
66 tensor::InsertSliceOp insertSliceOp,
68 assert(isa<RankedTensorType>(srcTensor.
getType()) &&
"not a ranked tensor");
72 LLVM_DEBUG(
DBGS() <<
"--find matching read for: " << insertSliceOp <<
"\n";
73 DBGS() <<
"--amongst users of: " << srcTensor <<
"\n");
76 if (forOp.isDefinedOutsideOfLoop(insertSliceOp.getDest()))
77 llvm::append_range(users, insertSliceOp.getDest().getUsers());
80 LLVM_DEBUG(
DBGS() <<
"----inspect user: " << *user <<
"\n");
81 auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
86 if (extractSliceOp.getResultType() != insertSliceOp.getSourceType() ||
87 !extractSliceOp.isSameAs(insertSliceOp, isSame)) {
88 LLVM_DEBUG(
DBGS() <<
"------not a matching extract_slice\n";
89 DBGS() << *user <<
" vs " << *insertSliceOp <<
"\n");
95 if (!isa<BlockArgument>(extractSliceOp.getSource()) &&
96 !forOp.isDefinedOutsideOfLoop(extractSliceOp.getSource())) {
97 LLVM_DEBUG(
DBGS() <<
"------transfer_read vector is loop-dependent\n");
100 return extractSliceOp;
107 LLVM_DEBUG(
DBGS() <<
"----no matching extract_slice");
120 vector::TransferWriteOp transferWriteOp,
122 if (!isa<RankedTensorType>(srcTensor.
getType()))
127 LLVM_DEBUG(
DBGS() <<
"--find matching read for: " << transferWriteOp <<
"\n";
128 DBGS() <<
"--amongst users of: " << srcTensor <<
"\n";);
136 if (forOp.isDefinedOutsideOfLoop(transferWriteOp.getSource()))
137 llvm::append_range(users, transferWriteOp.getSource().getUsers());
138 while (!users.empty()) {
140 LLVM_DEBUG(
DBGS() <<
"----inspect user: " << *user <<
"\n");
141 auto read = dyn_cast<vector::TransferReadOp>(user);
144 if (read.getIndices() != transferWriteOp.getIndices() ||
145 read.getVectorType() != transferWriteOp.getVectorType()) {
146 LLVM_DEBUG(
DBGS() <<
"------not a transfer_read that matches the "
148 << *user <<
"\n\t(vs " << *transferWriteOp <<
")\n");
155 if (!isa<BlockArgument>(read.getSource()) &&
156 !forOp.isDefinedOutsideOfLoop(read.getSource())) {
157 LLVM_DEBUG(
DBGS() <<
"------transfer_read vector appears loop "
158 "dependent but will be tested for disjointness as "
159 "part of the bypass analysis\n");
161 LLVM_DEBUG(
DBGS() <<
"------found match\n");
167 if (isa<vector::TransferWriteOp>(user)) {
172 cast<VectorTransferOpInterface>(user),
173 cast<VectorTransferOpInterface>(
174 transferWriteOp.getOperation()))) {
175 LLVM_DEBUG(
DBGS() <<
"----follow through disjoint write\n");
178 LLVM_DEBUG(
DBGS() <<
"----skip non-disjoint write\n");
183 LLVM_DEBUG(
DBGS() <<
"--no matching transfer_read\n");
185 "no matching transfer_read");
203 "bbArg and yieldOperand must match");
204 assert(isa<scf::YieldOp>(yieldOperand.
getOwner()) &&
"must be an scf.yield");
207 auto transferWriteOp = v.
getDefiningOp<vector::TransferWriteOp>();
208 if (!transferWriteOp)
211 if (transferWriteOp->getNumResults() == 0) {
213 "unsupported transfer_write on buffers");
223 v.
getLoc(),
"transfer_write indexing is loop-dependent");
225 return transferWriteOp;
247 "bbArg and yieldOperand must match");
248 assert(isa<scf::YieldOp>(yieldOperand.
getOwner()) &&
"must be an scf.yield");
251 auto insertSliceOp = v.
getDefiningOp<tensor::InsertSliceOp>();
258 if (bbArg != insertSliceOp.getDest())
264 v.
getLoc(),
"insert_slice indexing is loop-dependent");
266 return insertSliceOp;
280 uses.push_back(tensorArg.
getUses());
281 while (!uses.empty()) {
282 for (
OpOperand &use : uses.pop_back_val()) {
285 if (user == candidateReadOp || user == writeOp)
291 if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(user))
295 if (isa<vector::TransferWriteOp>(writeOp)) {
296 if (
auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
297 uses.push_back(writeUser->getResult(0).getUses());
305 if (
auto forUser = dyn_cast<scf::ForOp>(user)) {
306 Value arg = forUser.getBody()->getArgument(
307 use.getOperandNumber() - forUser.getNumControlOperands() +
314 scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
317 Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
324 if (isa<vector::TransferWriteOp>(writeOp)) {
325 auto read = dyn_cast<vector::TransferReadOp>(user);
327 cast<VectorTransferOpInterface>(read.getOperation()),
328 cast<VectorTransferOpInterface>(writeOp))) {
341 RewriterBase &rewriter, vector::TransferReadOp transferReadOp,
342 vector::TransferWriteOp transferWriteOp,
BlockArgument tensorBBArg) {
344 LLVM_DEBUG(
DBGS() <<
"--Start hoisting\n";
345 DBGS() <<
"--Hoist read : " << transferReadOp <<
"\n";
346 DBGS() <<
"--Hoist write: " << transferWriteOp <<
"\n";
347 DBGS() <<
"--Involving : " << tensorBBArg <<
"\n");
356 transferReadOp->moveBefore(forOp);
357 if (!forOp.isDefinedOutsideOfLoop(transferReadOp.getSource())) {
359 transferReadOp.getSourceMutable().assign(
360 forOp.getInitArgs()[initArgNumber]);
370 auto newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
371 rewriter, {transferReadOp.getVector()},
376 cast<scf::YieldOp>(newForOp.getRegion().front().getTerminator());
379 yieldOp->setOperand(initArgNumber, transferWriteOp.getSource());
385 transferWriteOp->moveAfter(newForOp);
387 transferWriteOp.getVectorMutable().assign(newForOp.getResults().back());
389 transferWriteOp.getSourceMutable().assign(newForOp.getResult(initArgNumber));
392 transferWriteOp.getResult(), transferWriteOp);
400 tensor::ExtractSliceOp extractSliceOp,
401 tensor::InsertSliceOp insertSliceOp,
404 LLVM_DEBUG(
DBGS() <<
"--Start hoisting\n";
405 DBGS() <<
"--Hoist read : " << extractSliceOp <<
"\n";
406 DBGS() <<
"--Hoist write: " << insertSliceOp <<
"\n";
407 DBGS() <<
"--Involving : " << tensorBBArg <<
"\n");
416 extractSliceOp->moveBefore(forOp);
417 if (!forOp.isDefinedOutsideOfLoop(extractSliceOp.getSource())) {
418 assert(extractSliceOp.getSource() == tensorBBArg &&
419 "extractSlice source not defined above must be the tracked bbArg");
421 extractSliceOp.getSourceMutable().assign(
422 forOp.getInitArgs()[initArgNumber]);
432 auto newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
433 rewriter, extractSliceOp.getResult(),
438 cast<scf::YieldOp>(newForOp.getRegion().front().getTerminator());
441 yieldOp->setOperand(initArgNumber, insertSliceOp.getDest());
447 insertSliceOp->moveAfter(newForOp);
449 insertSliceOp.getSourceMutable().assign(newForOp.getResults().back());
450 insertSliceOp.getDestMutable().assign(newForOp.getResult(initArgNumber));
453 insertSliceOp.getResult(), insertSliceOp);
464 LLVM_DEBUG(
DBGS() <<
"Enter hoistRedundantSubsetExtractInsert scf.for\n");
465 Operation *yield = forOp.getBody()->getTerminator();
467 LLVM_DEBUG(
DBGS() <<
"\n";
DBGS() <<
"Consider " << forOp <<
"\n");
469 scf::ForOp newForOp = forOp;
473 LLVM_DEBUG(
DBGS() <<
"Consider " << it.value() <<
"\n");
485 LLVM_DEBUG(
DBGS() <<
"no loop invariant write defining iter_args "
486 << it.value() <<
"\n");
491 ? transferWriteOp->getOperation()
492 : insertSliceOp->getOperation();
496 LLVM_DEBUG(
DBGS() <<
"write with more than 1 use " << *writeOp <<
"\n");
500 LLVM_DEBUG(
DBGS() <<
"Write to hoist: " << *writeOp <<
"\n");
507 rewriter, *transferWriteOp, it.value());
509 matchingReadOp = maybeTransferRead->getOperation();
512 rewriter, *insertSliceOp, it.value());
514 matchingReadOp = maybeExtractSlice->getOperation();
516 llvm_unreachable(
"unexpected case");
518 if (!matchingReadOp) {
519 LLVM_DEBUG(
DBGS() <<
"No matching read\n");
528 if (maybeUnknownOp) {
529 LLVM_DEBUG(
DBGS() <<
"Tensor chunk accessed by unknown op, skip: "
530 << *maybeUnknownOp <<
"\n");
536 LLVM_DEBUG(
DBGS() <<
"Read to hoist: " << *matchingReadOp <<
"\n");
539 rewriter, cast<vector::TransferReadOp>(matchingReadOp),
540 *transferWriteOp, it.value());
543 rewriter, cast<tensor::ExtractSliceOp>(matchingReadOp),
544 *insertSliceOp, it.value());
546 llvm_unreachable(
"unexpected case");
550 }
while (forOp != newForOp);
static FailureOr< vector::TransferWriteOp > getLoopInvariantTransferWriteDefining(RewriterBase &rewriter, scf::ForOp forOp, BlockArgument bbArg, OpOperand &yieldOperand)
Return the vector.transfer_write that produces yieldOperand, if:
static FailureOr< vector::TransferReadOp > findHoistableMatchingTransferRead(RewriterBase &rewriter, vector::TransferWriteOp transferWriteOp, BlockArgument srcTensor)
Given an srcTensor that is a block argument belonging to a loop.
static bool isSubsetLocationLoopInvariant(scf::ForOp forOp, vector::TransferWriteOp transferWriteOp)
Return true if the location of the subset defined by the op is invariant of the loop iteration.
static scf::ForOp hoistTransferReadWrite(RewriterBase &rewriter, vector::TransferReadOp transferReadOp, vector::TransferWriteOp transferWriteOp, BlockArgument tensorBBArg)
Mechanical hoisting of a matching read / write pair.
static FailureOr< tensor::InsertSliceOp > getLoopInvariantInsertSliceDefining(RewriterBase &rewriter, scf::ForOp forOp, BlockArgument bbArg, OpOperand &yieldOperand)
Return the tensor.insert_slice that produces yieldOperand, if:
static scf::ForOp hoistExtractInsertSlice(RewriterBase &rewriter, tensor::ExtractSliceOp extractSliceOp, tensor::InsertSliceOp insertSliceOp, BlockArgument tensorBBArg)
Mechanical hoisting of a matching read / write pair.
static FailureOr< tensor::ExtractSliceOp > findHoistableMatchingExtractSlice(RewriterBase &rewriter, tensor::InsertSliceOp insertSliceOp, BlockArgument srcTensor)
Given an srcTensor that is a block argument belonging to a loop.
static Operation * isTensorChunkAccessedByUnknownOp(Operation *writeOp, Operation *candidateReadOp, BlockArgument tensorArg)
Check if the chunk of data inserted by the writeOp is read by any other op than the candidateReadOp.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class provides support for representing a failure result, or a valid value of type T.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
OpOperand & getOpOperand(unsigned idx)
bool hasOneUse()
Returns true if this operation has exactly one use.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
user_range getUsers()
Returns a range of all users.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
virtual void startRootUpdate(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
scf::ForOp hoistRedundantSubsetExtractInsert(RewriterBase &rewriter, scf::ForOp forOp)
Greedily hoist redundant subset extract/insert operations on tensors outside of forOp.
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB)
Same behavior as isDisjointTransferSet but doesn't require the operations to have the same tensor/mem...
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.