#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

using llvm::dbgs;

#define DEBUG_TYPE "linalg-hoisting"

#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;
using namespace mlir::linalg;

/// A hoistable unit formed by a vector.transfer_write, together with the
/// optional tensor.insert_slice that inserts the written slice back into the
/// loop's iteration tensor.
struct HoistableWrite {
  vector::TransferWriteOp transferWriteOp;
  tensor::InsertSliceOp insertSliceOp;
};

/// A hoistable unit formed by a vector.transfer_read, together with the
/// optional tensor.extract_slice that produces the tensor it reads from.
struct HoistableRead {
  vector::TransferReadOp transferReadOp;
  tensor::ExtractSliceOp extractSliceOp;
};
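// Illustrative IR for a HoistableWrite with an insert_slice (a sketch, not
// from the original source; names, types and offsets are made up):
//
//   %w = vector.transfer_write %vec, %slice[%c0, %c0]
//       : vector<4x8xf32>, tensor<4x8xf32>
//   %y = tensor.insert_slice %w into %iterArg[%a, %b] [4, 8] [1, 1]
//       : tensor<4x8xf32> into tensor<?x?xf32>
//
// The matching HoistableRead is the transfer_read (plus extract_slice) that
// accesses the same slice at the same indices.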
/// Return true if op1 and op2 are the same constant or the same SSA value.
static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) {
  auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
    Attribute attr = ofr.dyn_cast<Attribute>();
    // If `ofr` wraps an SSA value, it only folds to a constant when produced
    // by an arith.constant.
    if (!attr && ofr.get<Value>().getDefiningOp<arith::ConstantOp>())
      attr = ofr.get<Value>().getDefiningOp<arith::ConstantOp>().getValue();
    if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
      return intAttr.getValue().getSExtValue();
    return llvm::None;
  };
  auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
  if (cst1 && cst2 && *cst1 == *cst2)
    return true;
  auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
  return v1 && v2 && v1 == v2;
}
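// For instance (a sketch): an OpFoldResult wrapping the attribute `4 : index`
// and one wrapping the result of `%c4 = arith.constant 4 : index` compare
// equal; two distinct loop-variant SSA values never do.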
/// Return true if all offsets, sizes and strides are equal.
static bool sameOffsetsSizesAndStrides(tensor::ExtractSliceOp s,
                                       tensor::InsertSliceOp si) {
  if (s.getStaticOffsets().size() != si.getStaticOffsets().size())
    return false;
  if (s.getStaticSizes().size() != si.getStaticSizes().size())
    return false;
  if (s.getStaticStrides().size() != si.getStaticStrides().size())
    return false;
  for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  return true;
}
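// Illustrative matching pair (a sketch with made-up values):
//
//   %e = tensor.extract_slice %t[%i, 0] [4, 8] [1, 1]
//       : tensor<?x?xf32> to tensor<4x8xf32>
//   %y = tensor.insert_slice %w into %t[%i, 0] [4, 8] [1, 1]
//       : tensor<4x8xf32> into tensor<?x?xf32>
//
// Offsets (%i, 0), sizes (4, 8) and strides (1, 1) all compare equal, so the
// two ops address the same tensor chunk.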
/// Look for a HoistableRead, among the given tensor's uses, that accesses the
/// same offsets as the HoistableWrite.
static HoistableRead findMatchingTransferRead(HoistableWrite write,
                                              Value srcTensor) {
  assert(write.transferWriteOp &&
         "expected hoistable write to have a .transfer_write");

  LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: "
                    << *write.transferWriteOp.getOperation() << "\n");
  if (write.insertSliceOp)
    LLVM_DEBUG(DBGS() << "findMatchingTransferRead insertSliceOp: "
                      << *write.insertSliceOp.getOperation() << "\n");

  for (Operation *user : srcTensor.getUsers()) {
    LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
                      << "\n");

    // If the HoistableWrite involves an insert_slice, a matching read must
    // go through a matching extract_slice.
    tensor::ExtractSliceOp sliceOp;
    Operation *maybeTransferReadUser = user;
    if (write.insertSliceOp) {
      sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
      if (!sliceOp || sliceOp.getResult().getType() !=
                          write.insertSliceOp.getSource().getType())
        continue;

      LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: "
                        << *sliceOp << " vs " << *write.insertSliceOp << "\n");
      if (!sameOffsetsSizesAndStrides(sliceOp, write.insertSliceOp))
        continue;

      LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n");
      // The extract_slice is hoistable iff it has exactly two uses: the
      // transfer_write we want to hoist and one other user, which must turn
      // out to be the matching transfer_read.
      bool skip = false;
      Operation *otherUser = nullptr;
      for (Operation *u : sliceOp->getUsers()) {
        if (u == write.transferWriteOp)
          continue;
        if (otherUser) {
          skip = true;
          break;
        }
        otherUser = u;
      }
      if (skip || !otherUser)
        continue;
      maybeTransferReadUser = otherUser;
    }

    LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser
                      << "\n");
    auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser);
    if (read && read.getIndices() == write.transferWriteOp.getIndices() &&
        read.getVectorType() == write.transferWriteOp.getVectorType())
      return HoistableRead{read, sliceOp};
  }
  return HoistableRead();
}
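// Illustrative match (a sketch with made-up names): a read unit that pairs
// with a write unit inserting at the same [%a, %b] offsets:
//
//   %e = tensor.extract_slice %iterArg[%a, %b] [4, 8] [1, 1]
//       : tensor<?x?xf32> to tensor<4x8xf32>
//   %v = vector.transfer_read %e[%c0, %c0], %pad
//       : tensor<4x8xf32>, vector<4x8xf32>
//
// Here the extract_slice has exactly two uses: the transfer_read above and
// the hoistable transfer_write.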
/// Check whether the chunk of data inserted by the HoistableWrite is read by
/// any op other than the HoistableRead candidate.
static bool tensorChunkAccessedByUnknownOp(HoistableWrite write,
                                           HoistableRead candidateRead,
                                           BlockArgument tensorArg) {
  // Make sure none of the other uses reads the part of the tensor modified
  // by the transfer_write.
  llvm::SmallVector<Value::use_range, 1> uses;
  uses.push_back(tensorArg.getUses());
  while (!uses.empty()) {
    for (OpOperand &use : uses.pop_back_val()) {
      Operation *user = use.getOwner();
      if (user == candidateRead.transferReadOp ||
          user == candidateRead.extractSliceOp ||
          user == write.transferWriteOp || user == write.insertSliceOp)
        continue;
      // Bail on any other slice op: a stronger analysis would be needed.
      if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(user))
        return true;
      // Follow transitive uses through a vector.transfer_write.
      if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
        uses.push_back(writeUser->getResult(0).getUses());
        continue;
      }
      // Follow nested uses through an scf::ForOp: pass-through tensor
      // arguments may be left over from previous levels of hoisting.
      if (auto forUser = dyn_cast<scf::ForOp>(user)) {
        Value arg = forUser.getLoopBody().getArgument(
            use.getOperandNumber() - forUser.getNumControlOperands() +
            /*iv=*/1);
        uses.push_back(arg.getUses());
        continue;
      }
      // Follow a yield as long as it doesn't escape the original region.
      scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
      if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor(
                           yieldUser->getParentOp())) {
        Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
        uses.push_back(ret.getUses());
        continue;
      }
      // Any remaining user must be a transfer_read provably disjoint from
      // the write.
      auto read = dyn_cast<vector::TransferReadOp>(user);
      if (!read || !vector::isDisjointTransferIndices(
                       cast<VectorTransferOpInterface>(read.getOperation()),
                       cast<VectorTransferOpInterface>(
                           write.transferWriteOp.getOperation())))
        return true;
    }
  }
  return false;
}
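// Example of a disqualifying use (a sketch): any other read of the written
// chunk that cannot be proven disjoint, e.g.
//
//   %other = vector.transfer_read %iterArg[%a, %b], %pad
//       : tensor<?x?xf32>, vector<4x8xf32>
//
// makes this function return true and keeps the pair inside the loop.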
/// Return the forOp-invariant HoistableWrite that produces `yieldOperand`.
/// Return the null HoistableWrite() if it is not comprised of a
/// vector.transfer_write + optional tensor.insert_slice, or if any of the
/// indexings is `forOp`-dependent.
static HoistableWrite
getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
                                        OpOperand &yieldOperand) {
  Value v = yieldOperand.get();
  if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
    // Indexing must not depend on `forOp`.
    for (Value operand : write.getIndices())
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, nullptr};
  }

  if (auto insertSliceOp = v.getDefiningOp<tensor::InsertSliceOp>()) {
    // The inserted slice must come from a vector.transfer_write.
    auto write =
        insertSliceOp.getSource().getDefiningOp<vector::TransferWriteOp>();
    if (!write)
      return HoistableWrite();

    // The tensor inserted into must be a BBArg of `forOp` at the position
    // matching `yieldOperand`'s.
    auto bbArg = insertSliceOp.getDest().dyn_cast<BlockArgument>();
    if (!bbArg || bbArg.getOwner()->getParentOp() != forOp ||
        bbArg.getArgNumber() != /*numIvs=*/1 + yieldOperand.getOperandNumber())
      return HoistableWrite();

    // Indexing must not depend on `forOp`.
    for (Value operand : insertSliceOp->getOperands().drop_front(
             tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex()))
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, insertSliceOp};
  }

  return HoistableWrite();
}
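// Illustrative qualifying yield (a sketch with made-up names): the value
// yielded for the iter_arg %t is produced by a write unit whose indexings are
// all defined above the loop:
//
//   %r = scf.for ... iter_args(%t = %init) -> (tensor<?x?xf32>) {
//     ...
//     %w = vector.transfer_write %vec, %e[%c0, %c0]
//         : vector<4x8xf32>, tensor<4x8xf32>
//     %y = tensor.insert_slice %w into %t[%a, %b] [4, 8] [1, 1]
//         : tensor<4x8xf32> into tensor<?x?xf32>
//     scf.yield %y : tensor<?x?xf32>
//   }
//
// with %a, %b and %c0 defined outside the loop.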
/// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair.
static void hoistReadWrite(HoistableRead read, HoistableWrite write,
                           BlockArgument tensorBBArg) {
  scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
  assert(read.transferReadOp && write.transferWriteOp &&
         "expected transfer_read and transfer_write ops to be set");
  assert(((read.extractSliceOp && write.insertSliceOp) ||
          (!read.extractSliceOp && !write.insertSliceOp)) &&
         "expected matching extract_slice / insert_slice");
  LLVM_DEBUG(DBGS() << "In forOp:\n"
                    << *forOp.getOperation()
                    << "\nHoist: " << *read.transferReadOp.getOperation()
                    << "\nHoist: " << *write.transferWriteOp.getOperation()
                    << "\nInvolving: " << tensorBBArg << "\n");

  // If a read slice is present, hoist it.
  if (read.extractSliceOp)
    forOp.moveOutOfLoop(read.extractSliceOp);

  // Hoist the transfer_read op.
  forOp.moveOutOfLoop(read.transferReadOp);

  // The BBArg list is [iv, iter_args...]: skip the induction variable.
  unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;

  // Update the source tensor to read from the corresponding init arg.
  if (read.extractSliceOp)
    read.extractSliceOp.getSourceMutable().assign(
        forOp.getInitArgs()[initArgNumber]);
  else
    read.transferReadOp.getSourceMutable().assign(
        forOp.getInitArgs()[initArgNumber]);

  // Hoist the write (and insert_slice, if any) after the loop.
  if (write.insertSliceOp)
    write.insertSliceOp->moveAfter(forOp);
  write.transferWriteOp->moveAfter(forOp);

  // Update the yield to forward the unmodified tensor.
  auto yieldOp = cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
  if (write.insertSliceOp)
    yieldOp->setOperand(initArgNumber, write.insertSliceOp.getDest());
  else
    yieldOp->setOperand(initArgNumber, write.transferWriteOp.getSource());

  // Rewrite `forOp` so it also yields the vector that feeds the hoisted
  // write.
  OpBuilder b(read.transferReadOp);
  NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc,
                                ArrayRef<BlockArgument> newBBArgs) {
    return SmallVector<Value>{write.transferWriteOp.getVector()};
  };
  auto newForOp = replaceLoopWithNewYields(
      b, forOp, read.transferReadOp.getVector(), yieldFn);

  // The write has been hoisted: update its vector and tensor sources and
  // replace the loop result with the tensor now produced outside the loop.
  // If an insert_slice is present, it carries the tensor update.
  if (write.insertSliceOp) {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.insertSliceOp.getResult());
    write.transferWriteOp.getSourceMutable().assign(
        read.extractSliceOp.getResult());
    write.insertSliceOp.getDestMutable().assign(
        read.extractSliceOp.getSource());
  } else {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.transferWriteOp.getResult());
    write.transferWriteOp.getSourceMutable().assign(
        newForOp.getResult(initArgNumber));
  }

  // Always write the vector yielded by the new loop.
  write.transferWriteOp.getVectorMutable().assign(
      newForOp.getResults().back());
}
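// End-to-end effect of hoistReadWrite (a sketch, not from the original
// source; names and types are made up). Before:
//
//   %r = scf.for %i = %lb to %ub step %s iter_args(%t = %init)
//       -> (tensor<?x?xf32>) {
//     %e = tensor.extract_slice %t[%a, %b] [4, 8] [1, 1]
//     %v = vector.transfer_read %e[%c0, %c0], %pad
//     %v2 = "test.compute"(%v) : (vector<4x8xf32>) -> vector<4x8xf32>
//     %w = vector.transfer_write %v2, %e[%c0, %c0]
//     %y = tensor.insert_slice %w into %t[%a, %b] [4, 8] [1, 1]
//     scf.yield %y : tensor<?x?xf32>
//   }
//
// After (the loop now circulates the vector; read and write moved out):
//
//   %e = tensor.extract_slice %init[%a, %b] [4, 8] [1, 1]
//   %v = vector.transfer_read %e[%c0, %c0], %pad
//   %r:2 = scf.for %i = %lb to %ub step %s iter_args(%t = %init, %iv = %v)
//       -> (tensor<?x?xf32>, vector<4x8xf32>) {
//     %v2 = "test.compute"(%iv) : (vector<4x8xf32>) -> vector<4x8xf32>
//     scf.yield %t, %v2 : tensor<?x?xf32>, vector<4x8xf32>
//   }
//   %w = vector.transfer_write %r#1, %e[%c0, %c0]
//   %y = tensor.insert_slice %w into %init[%a, %b] [4, 8] [1, 1]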
/// Same behavior as hoistRedundantVectorTransfers but works on tensors
/// instead of buffers.
void mlir::linalg::hoistRedundantVectorTransfersOnTensor(func::FuncOp func) {
  bool changed = true;
  while (changed) {
    changed = false;
    func.walk([&](scf::ForOp forOp) {
      Operation *yield = forOp.getBody()->getTerminator();
      for (const auto &it : llvm::enumerate(forOp.getRegionIterArgs())) {
        OpOperand &ret = yield->getOpOperand(it.index());
        HoistableWrite write =
            getLoopInvariantTransferWriteOpDefining(forOp, ret);
        if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse())
          continue;
        LLVM_DEBUG(dbgs() << "\n";
                   DBGS() << "Candidate write for hoisting: "
                          << *write.transferWriteOp.getOperation() << "\n");
        if (write.insertSliceOp)
          LLVM_DEBUG(DBGS() << "Candidate insert_slice for hoisting: "
                            << *write.insertSliceOp.getOperation() << "\n");
        if (llvm::any_of(write.transferWriteOp.getIndices(),
                         [&forOp](Value index) {
                           return !forOp.isDefinedOutsideOfLoop(index);
                         }))
          continue;
        // Find a matching candidate read.
        HoistableRead matchingRead =
            findMatchingTransferRead(write, it.value());
        // Make sure no other use reads the part of the tensor modified by
        // the transfer_write.
        if (!matchingRead.transferReadOp ||
            tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
          continue;

        LLVM_DEBUG(DBGS() << "Start hoisting\n");
        hoistReadWrite(matchingRead, write, it.value());
        changed = true;
      }
    });
    // Fold the now-redundant loop arguments and yields before looking for
    // more hoisting opportunities.
    if (changed) {
      RewritePatternSet patterns(func->getContext());
      scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext());
      (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
    }
  }
}
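// Minimal usage sketch (an assumption, not part of this file): both entry
// points are typically driven from a pass over func.func. The pass class and
// its name below are made up for illustration:
//
//   #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
//
//   struct ExampleHoistingPass
//       : public PassWrapper<ExampleHoistingPass,
//                            OperationPass<func::FuncOp>> {
//     void runOnOperation() override {
//       linalg::hoistRedundantVectorTransfersOnTensor(getOperation());
//       linalg::hoistRedundantVectorTransfers(getOperation());
//     }
//   };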
/// Hoist vector.transfer_read / vector.transfer_write pairs on buffers out of
/// the immediately enclosing scf::ForOp, iteratively. Hoisting makes the loop
/// yield the vector value that previously transited through memory.
void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
  bool changed = true;
  while (changed) {
    changed = false;
    // First move loop-invariant ops out of their loops; ops cannot be moved
    // while the walk below is in progress.
    func.walk(
        [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });

    func.walk([&](vector::TransferReadOp transferRead) {
      if (!transferRead.getShapedType().isa<MemRefType>())
        return WalkResult::advance();
      LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
                        << *transferRead.getOperation() << "\n");
      auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
      LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
                        << "\n");
      if (!loop)
        return WalkResult::advance();
      LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
                        << "\n");

      // Find the last transfer_write in the forward slice of `transferRead`
      // that accesses the same memref.
      SetVector<Operation *> forwardSlice;
      getForwardSlice(transferRead.getOperation(), &forwardSlice);
      vector::TransferWriteOp transferWrite;
      for (auto *sliceOp : llvm::reverse(forwardSlice)) {
        auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
        if (!candidateWrite ||
            candidateWrite.getSource() != transferRead.getSource())
          continue;
        transferWrite = candidateWrite;
      }

      // All operands of the transfer_read must be defined outside the loop.
      for (auto operand : transferRead.getOperands())
        if (!loop.isDefinedOutsideOfLoop(operand))
          return WalkResult::advance();
      // Only hoist transfer_read / transfer_write pairs for now.
      if (!transferWrite)
        return WalkResult::advance();
      LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
                        << "\n");

      // The pair is hoistable only if read and write access the very same
      // location: same indices and same vector type.
      if (transferRead.getIndices() != transferWrite.getIndices() ||
          transferRead.getVectorType() != transferWrite.getVectorType())
        return WalkResult::advance();
      // The read must dominate the write it is paired with.
      DominanceInfo dom(loop);
      if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
        return WalkResult::advance();

      // Approximate aliasing: no other op in the loop may access the same
      // memref unless it is a transfer op on a statically disjoint slice.
      for (auto &use : transferRead.getSource().getUses()) {
        if (!loop->isAncestor(use.getOwner()))
          continue;
        if (use.getOwner() == transferRead.getOperation() ||
            use.getOwner() == transferWrite.getOperation())
          continue;
        if (auto transferWriteUse =
                dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
          if (!vector::isDisjointTransferSet(
                  cast<VectorTransferOpInterface>(transferWrite.getOperation()),
                  cast<VectorTransferOpInterface>(
                      transferWriteUse.getOperation())))
            return WalkResult::advance();
        } else if (auto transferReadUse =
                       dyn_cast<vector::TransferReadOp>(use.getOwner())) {
          if (!vector::isDisjointTransferSet(
                  cast<VectorTransferOpInterface>(transferWrite.getOperation()),
                  cast<VectorTransferOpInterface>(
                      transferReadUse.getOperation())))
            return WalkResult::advance();
        } else {
          // Unknown use: cannot prove it doesn't alias with the pair.
          return WalkResult::advance();
        }
      }

      // Hoist the read before the loop and the write after it.
      loop.moveOutOfLoop(transferRead);
      transferWrite->moveAfter(loop);

      // Rewrite `loop` so it yields the vector carried from the hoisted read
      // to the hoisted write.
      OpBuilder b(transferRead);
      NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc,
                                    ArrayRef<BlockArgument> newBBArgs) {
        return SmallVector<Value>{transferWrite.getVector()};
      };
      auto newForOp = replaceLoopWithNewYields(
          b, loop, transferRead.getVector(), yieldFn);
      transferWrite.getVectorMutable().assign(newForOp.getResults().back());

      changed = true;
      loop.erase();
      // Interrupt and restart: erasing the loop invalidates the walk.
      return WalkResult::interrupt();
    });
  }
}
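// Illustrative before/after on buffers (a sketch with made-up names):
//
//   scf.for %i = %lb to %ub step %s {
//     %v = vector.transfer_read %A[%c0, %c0], %pad
//         : memref<?x?xf32>, vector<4xf32>
//     %v2 = "test.compute"(%v) : (vector<4xf32>) -> vector<4xf32>
//     vector.transfer_write %v2, %A[%c0, %c0]
//         : vector<4xf32>, memref<?x?xf32>
//   }
//
// becomes
//
//   %v0 = vector.transfer_read %A[%c0, %c0], %pad
//       : memref<?x?xf32>, vector<4xf32>
//   %r = scf.for %i = %lb to %ub step %s iter_args(%v = %v0)
//       -> (vector<4xf32>) {
//     %v2 = "test.compute"(%v) : (vector<4xf32>) -> vector<4xf32>
//     scf.yield %v2 : vector<4xf32>
//   }
//   vector.transfer_write %r, %A[%c0, %c0] : vector<4xf32>, memref<?x?xf32>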