#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;
/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

// (Overloads for the vector transfer/load/store ops were elided here; each
// simply forwards to the op's base-memref accessor.)

static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}
namespace {

/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with load operation.
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges subview operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with store operation.
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with store operation.
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};
/// Folds subview(subview(x)) into a single subview(x).
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();

    // TODO: relax the unit stride requirement.
    if (!subView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(subView, "requires unit strides");
    }
    if (!srcSubView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);

    // Resolve offsets according to the source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
        resolvedOffsets);

    // Replace the original subview with one rooted at the source of the
    // source subview.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        srcSubView.getMixedStrides());
    return success();
  }
};
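// Illustrative example (not from the original source; the types and offsets
// below are computed by hand assuming unit strides): the pattern above
// rewrites
//
//   %0 = memref.subview %arg[4, 8] [16, 16] [1, 1]
//          : memref<64x64xf32>
//          to memref<16x16xf32, strided<[64, 1], offset: 264>>
//   %1 = memref.subview %0[2, 2] [8, 8] [1, 1]
//          : memref<16x16xf32, strided<[64, 1], offset: 264>>
//          to memref<8x8xf32, strided<[64, 1], offset: 394>>
//
// into a single subview rooted at the original memref, with the offsets
// added together (4 + 2 = 6 and 8 + 2 = 10, i.e. 6 * 64 + 10 = 394):
//
//   %1 = memref.subview %arg[6, 10] [8, 8] [1, 1]
//          : memref<64x64xf32>
//          to memref<8x8xf32, strided<[64, 1], offset: 394>>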
/// Folds a subview on the source and/or destination memref of an
/// nvgpu.device_async_copy into the copy itself.
class NVGPUAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};

} // namespace
/// Expands the indices of an affine load/store through its affine map, so
/// the folders below can work with the "actual" access indices.
static SmallVector<Value>
calculateExpandedAccessIndices(AffineMap affineMap,
                               const SmallVector<Value> &indices, Location loc,
                               PatternRewriter &rewriter) {
  SmallVector<OpFoldResult> indicesOfr(llvm::to_vector(
      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; })));
  SmallVector<Value> expandedIndices;
  for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
    expandedIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return expandedIndices;
}
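// For example (illustrative): an affine.load with map
// (d0, d1) -> (d0 + 1, d1 * 2) and indices (%i, %j) yields expanded indices
// for %i + 1 and %j * 2, materialized through composed affine.apply ops, so
// the folding logic can treat the access as if it used plain indices.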
template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(llvm::is_one_of<XferOp, vector::TransferReadOp,
                                vector::TransferWriteOp>::value,
                "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}
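// For example (illustrative): a vector.transfer_read of a subview with
// strides [2, 1] is rejected by the preconditions above, because a vector
// transfer spans a contiguous slice of memory, so the non-unit stride would
// have to be carried by the folded memref type, which this rewrite does not
// track.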
template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, apply the affine map to obtain the "actual" access
  // indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();
  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, apply the affine map to obtain the "actual" access
  // indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  // memref.load and affine.load guarantee that indices start inbounds, while
  // the vector operations don't; this affects whether the linearization is
  // `disjoint`.
  if (failed(resolveSourceIndicesExpandShape(
          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
          isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
    return failure();
  return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices);
        return success();
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
        return success();
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
        return success();
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
        return success();
      })
      .Case([&](vector::TransferReadOp op) {
        // Only minor-identity permutation maps are supported here.
        if (!op.getPermutationMap().isMinorIdentity())
          return rewriter.notifyMatchFailure(op, "unsupported permutation map");
        // The source of the expand_shape has a different rank, so the read
        // is rebuilt with a minor identity map of the source rank.
        const int64_t sourceRank = sourceIndices.size();
        const int64_t vectorRank = op.getVectorType().getRank();
        if (sourceRank < vectorRank)
          return rewriter.notifyMatchFailure(
              op, "vector rank exceeds the rank of the source memref");
        AffineMap minorIdMap = AffineMap::getMinorIdentityMap(
            sourceRank, vectorRank, op.getContext());
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), expandShapeOp.getViewSource(),
            sourceIndices, minorIdMap, op.getPadding(), op.getMask(),
            op.getInBoundsAttr());
        return success();
      })
      .DefaultUnreachable("unexpected operation");
}
template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp =
      getMemRefOperand(loadOp)
          .template getDefiningOp<memref::CollapseShapeOp>();
  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, apply the affine map to obtain the "actual" access
  // indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, apply the affine map to obtain the "actual" access
  // indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();
  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, apply the affine map to obtain the "actual" access
  // indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  // memref.store and affine.store guarantee that indices start inbounds,
  // while the vector operations don't; this affects whether the
  // linearization is `disjoint`.
  if (failed(resolveSourceIndicesExpandShape(
          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
          isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp =
      getMemRefOperand(storeOp)
          .template getDefiningOp<memref::CollapseShapeOp>();
  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, apply the affine map to obtain the "actual" access
  // indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          storeOp.getLoc(), rewriter, collapseShapeOp, indices,
          sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {

  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, resolve the indices into the source memref.
  SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
                                copyOp.getSrcIndices().end());
  SmallVector<Value> foldedSrcIndices(srcindices);

  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        srcindices, foldedSrcIndices);
  }

  // If the destination is a subview, resolve the indices likewise.
  SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
                                copyOp.getDstIndices().end());
  SmallVector<Value> foldedDstIndices(dstindices);

  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        dstindices, foldedDstIndices);
  }

  // Replace the copy op with a new one that reads from / writes to the
  // sources of the subviews directly.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}
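// Illustrative example (not from the original source): a copy whose source
// is a subview,
//
//   %sv = memref.subview %src[%a, %b] [1, 4] [1, 1]
//           : memref<128x128xf32>
//           to memref<1x4xf32, strided<[128, 1], offset: ?>>
//   %token = nvgpu.device_async_copy %sv[%c0, %c0], %dst[%i, %j], 4
//           : memref<1x4xf32, strided<[128, 1], offset: ?>>
//           to memref<64x64xf32, 3>
//
// is rewritten to copy directly from %src at the folded indices
// (%a + %c0, %b + %c0); a subview on the destination folds the same way.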
void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
               LoadOpOfSubViewOpFolder<memref::LoadOp>,
               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
               StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
               StoreOpOfSubViewOpFolder<memref::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
               StoreOpOfSubViewOpFolder<vector::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
               StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
               LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
               StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
               LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
               SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}
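// Clients can also drive these patterns outside the pass below; a minimal
// sketch (assuming the standard greedy driver and an `op` isolated from
// above):
//
//   RewritePatternSet patterns(op->getContext());
//   memref::populateFoldMemRefAliasOpPatterns(patterns);
//   (void)applyPatternsGreedily(op, std::move(patterns));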
namespace {

struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
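// From the command line, these rewrites are typically exercised with
// `mlir-opt --fold-memref-alias-ops`; pattern-level tracing is available via
// `-debug-only=fold-memref-alias-ops` (see DEBUG_TYPE at the top of the
// file).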