#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  // ...
private:
  // ...
  std::vector<Operation *> opToErase;
};

bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the dest.
  if (dominators.dominates(start, dest))
    return true;
  // ...
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    // ...
  }
  return false;
}
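
// Dead store elimination: a transfer_write is dead if a post-dominating
// transfer_write overwrites the same memory and no conflicting access is
// reachable in between. Illustrative example (hypothetical IR, exact
// attributes may differ):
//
//   vector.transfer_write %v0, %m[%c0, %c0] {in_bounds = [true]}
//       : vector<4xf32>, memref<4x4xf32>
//   vector.transfer_write %v1, %m[%c0, %c0] {in_bounds = [true]}
//       : vector<4xf32>, memref<4x4xf32>
//
// With no read of %m between the two writes, the first write is dead and is
// queued in opToErase.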
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = write.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidates that fully overwrite the store.
      if (write.getSource() == nextWrite.getSource() &&
          vector::checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation())))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}
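
// Store-to-load forwarding: a transfer_read can reuse the vector of a
// dominating transfer_write that writes the same location with the same
// shape, provided no conflicting write is reachable in between. Illustrative
// example (hypothetical IR, exact attributes may differ):
//
//   vector.transfer_write %v, %m[%c0, %c0] {in_bounds = [true]}
//       : vector<4xf32>, memref<4x4xf32>
//   %r = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true]}
//       : memref<4x4xf32>, vector<4xf32>
//
// Uses of %r can be replaced by %v and the read is queued in opToErase.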
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = read.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // A write that is proven disjoint from the read can be ignored.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation())))
        continue;
      if (write.getSource() == read.getSource() &&
          dominators.dominates(write, read) &&
          vector::checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }
  if (lastwrite == nullptr)
    return;
  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
                               ArrayRef<int64_t> sizes,
                               ArrayRef<int64_t> strides) {
  SmallVector<int64_t> targetShape = llvm::to_vector(
      llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
}

static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = input.getType().cast<MemRefType>();
  assert(inputType.hasStaticShape());
  // ...
  MemRefType resultType =
      dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
  // ...
  return rewriter.create<memref::SubViewOp>(
      loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
}

static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

static bool isZero(OpFoldResult v) {
  auto cst = getConstantIntValue(v);
  return cst && cst.value() == 0;
}
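
// Drops unit dims from a transfer_read source via a rank-reducing subview.
// Illustrative example (hypothetical IR, exact attributes may differ):
//
//   %0 = vector.transfer_read %src[%c0, %c0, %c0], %pad {in_bounds = [true]}
//       : memref<1x4x1xf32>, vector<4xf32>
//
// becomes roughly:
//
//   %sv = memref.subview %src[0, 0, 0] [1, 4, 1] [1, 1, 1]
//       : memref<1x4x1xf32> to memref<4xf32>
//   %0 = vector.transfer_read %sv[%c0], %pad
//       : memref<4xf32>, vector<4xf32>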
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    if (reducedRank != vectorType.getRank())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
    return success();
  }
};
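
// Same rewrite for the write side. Illustrative example (hypothetical IR):
//
//   vector.transfer_write %v, %dst[%c0, %c0, %c0] {in_bounds = [true]}
//       : vector<4xf32>, memref<1x4x1xf32>
//
// becomes roughly:
//
//   %sv = memref.subview %dst[0, 0, 0] [1, 4, 1] [1, 1, 1]
//       : memref<1x4x1xf32> to memref<4xf32>
//   vector.transfer_write %v, %sv[%c0] : vector<4xf32>, memref<4xf32>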
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    if (reducedRank != vectorType.getRank())
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(),
                     [](Value v) { return !isZero(v); }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
    return success();
  }
};

static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
                                              ArrayRef<int64_t> targetShape) {
  auto shape = memrefType.getShape();
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(getStridesAndOffset(memrefType, strides, offset)))
    return false;
  if (strides.back() != 1)
    return false;
  strides.pop_back();
  int64_t flatDim = 1;
  for (auto [targetDim, memrefDim, memrefStride] :
       llvm::reverse(llvm::zip(targetShape, shape, strides))) {
    flatDim *= memrefDim;
    if (flatDim != memrefStride || targetDim != memrefDim)
      return false;
  }
  return true;
}

static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = input.getType().cast<ShapedType>();
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                 SmallVector<Value> &outIndices) {
  int64_t rank = indices.size();
  if (firstDimToCollapse >= rank)
    return failure();
  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
    arith::ConstantIndexOp cst =
        indices[i].getDefiningOp<arith::ConstantIndexOp>();
    if (!cst || cst.value() != 0)
      return failure();
  }
  outIndices = indices;
  outIndices.resize(firstDimToCollapse + 1);
  return success();
}
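
// Flattens an n-D transfer_read on a contiguous row-major memref into a 1-D
// transfer_read plus a shape_cast. Illustrative example (hypothetical IR,
// exact reassociation and attributes may differ):
//
//   %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
//       : memref<4x8xf32>, vector<4x8xf32>
//
// becomes roughly:
//
//   %flat_src = memref.collapse_shape %src [[0, 1]]
//       : memref<4x8xf32> into memref<32xf32>
//   %flat = vector.transfer_read %flat_src[%c0], %pad {in_bounds = [true]}
//       : memref<32xf32>, vector<32xf32>
//   %0 = vector.shape_cast %flat : vector<32xf32> to vector<4x8xf32>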
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferReadOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0-D/1-D, nothing to do.
      return failure();
    if (!hasMatchingInnerContigousShape(
            sourceType,
            vectorType.getShape().take_back(vectorType.getRank() - 1)))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        collapsedSource.getType().cast<MemRefType>();
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, vector.getType().cast<VectorType>(), flatRead);
    return success();
  }
};
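
// Same flattening for the write side. Illustrative example (hypothetical IR):
//
//   vector.transfer_write %v, %dst[%c0, %c0] {in_bounds = [true, true]}
//       : vector<4x8xf32>, memref<4x8xf32>
//
// becomes roughly:
//
//   %flat_dst = memref.collapse_shape %dst [[0, 1]]
//       : memref<4x8xf32> into memref<32xf32>
//   %flat = vector.shape_cast %v : vector<4x8xf32> to vector<32xf32>
//   vector.transfer_write %flat, %flat_dst[%c0] {in_bounds = [true]}
//       : vector<32xf32>, memref<32xf32>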
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = vector.getType().cast<VectorType>();
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      return failure();
    if (!hasMatchingInnerContigousShape(
            sourceType,
            vectorType.getShape().take_back(vectorType.getRank() - 1)))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        collapsedSource.getType().cast<MemRefType>();
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.eraseOp(transferWriteOp);
    return success();
  }
};
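
// Rewrites a single-use extractelement of a transfer_read into a scalar
// memref.load (or tensor.extract). Illustrative example (hypothetical IR):
//
//   %0 = vector.transfer_read %m[%i], %pad {in_bounds = [true]}
//       : memref<?xf32>, vector<8xf32>
//   %1 = vector.extractelement %0[%j : index] : vector<8xf32>
//
// becomes roughly:
//
//   %idx = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%i, %j]
//   %1 = memref.load %m[%idx] : memref<?xf32>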
class RewriteScalarExtractElementOfTransferRead
    : public OpRewritePattern<vector::ExtractElementOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Rewriting a multi-use transfer_read could introduce redundant loads.
    if (!extractOp.getVector().hasOneUse())
      return failure();
    if (xferOp.getMask())
      return failure();
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0)
      return failure();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(), sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (ofr.is<Value>()) {
        newIndices[newIndices.size() - 1] = ofr.get<Value>();
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(extractOp.getLoc(),
                                                    *getConstantIntValue(ofr));
      }
    }
    if (xferOp.getSource().getType().isa<MemRefType>()) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(
          extractOp, xferOp.getSource(), newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
    return success();
  }
};
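
// Same rewrite for vector.extract with constant positions. Illustrative
// example (hypothetical IR, older vector.extract syntax):
//
//   %0 = vector.transfer_read %m[%i, %c0], %pad {in_bounds = [true]}
//       : memref<?x8xf32>, vector<8xf32>
//   %1 = vector.extract %0[2] : vector<8xf32>
//
// becomes roughly:
//
//   %c2 = arith.constant 2 : index
//   %1 = memref.load %m[%i, %c2] : memref<?x8xf32>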
class RewriteScalarExtractOfTransferRead
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only match scalar extracts.
    if (extractOp.getType().isa<VectorType>())
      return failure();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    if (!extractOp.getVector().hasOneUse())
      return failure();
    if (xferOp.getMask())
      return failure();
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0)
      return failure();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (const auto &it : llvm::enumerate(extractOp.getPosition())) {
      int64_t offset = it.value().cast<IntegerAttr>().getInt();
      int64_t idx =
          newIndices.size() - extractOp.getPosition().size() + it.index();
      OpFoldResult ofr = makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (ofr.is<Value>()) {
        newIndices[idx] = ofr.get<Value>();
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (xferOp.getSource().getType().isa<MemRefType>()) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(
          extractOp, xferOp.getSource(), newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
    return success();
  }
};
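
// Rewrites a transfer_write of a single-element vector into a scalar
// memref.store (or tensor.insert). Illustrative example (hypothetical IR):
//
//   vector.transfer_write %v, %m[%i, %j] {in_bounds = [true, true]}
//       : vector<1x1xf32>, memref<?x?xf32>
//
// becomes roughly:
//
//   %s = vector.extract %v[0, 0] : vector<1x1xf32>
//   memref.store %s, %m[%i, %j] : memref<?x?xf32>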
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    if (xferOp.getMask())
      return failure();
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    Value scalar;
    if (vecType.getRank() == 0) {
      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
                                                         xferOp.getVector());
    } else {
      SmallVector<int64_t> pos(vecType.getRank(), 0);
      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
                                                  xferOp.getVector(), pos);
    }
    if (xferOp.getSource().getType().isa<MemRefType>()) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    }
    return success();
  }
};
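
// Entry point that drives both dataflow optimizations above. Store-to-load
// forwarding runs first because it can expose additional dead stores for the
// second walk to erase.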
void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  rootOp->walk([&](vector::TransferReadOp read) {
    if (read.getShapedType().isa<MemRefType>())
      opt.storeToLoadForwarding(read);
  });
  // ...
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (write.getShapedType().isa<MemRefType>())
      opt.deadStoreOp(write);
  });
  // ...
}

void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead, RewriteScalarWrite>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
}
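
// Typical usage of the populate* entry points above (illustrative sketch; a
// real pass would pick only the pattern sets it needs):
//
//   RewritePatternSet patterns(op->getContext());
//   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
//   vector::populateFlattenVectorTransferPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));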