#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;
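// The GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS include above generates
// memref::impl::FoldMemRefAliasOpsPassBase, the CRTP base class that the
// FoldMemRefAliasOpsPass definition at the bottom of this file derives from.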
/// Helpers to access the memref operand for each op. The generic template
/// handles ops exposing getMemref(); dedicated overloads handle ops whose
/// memref accessor is named differently.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

static Value getMemRefOperand(vector::TransferReadOp op) {
  return op.getSource();
}
static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.getSource();
}
static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }
static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }
static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }
static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }

static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}
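// A minimal usage sketch (the helper below is illustrative, not part of this
// file): an unqualified call picks the dedicated overload when one exists and
// falls back to the generic template for ops exposing getMemref().
//
//   template <typename OpTy>
//   static memref::SubViewOp getSubViewProducer(OpTy op) {
//     return getMemRefOperand(op).template getDefiningOp<memref::SubViewOp>();
//   }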
namespace {

/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};
/// Merges expand_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};
/// Merges collapse_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};
/// Merges subview operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};
/// Merges expand_shape operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};
/// Merges collapse_shape operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};
/// Folds subview(subview(x)) into a single subview(x).
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();

    // TODO: relax the unit stride assumption.
    if (!subView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(subView, "requires unit strides");
    }
    if (!srcSubView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);

    // Resolve offsets according to the source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
        resolvedOffsets);

    // Replace the original op with a subview of the source's source.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        srcSubView.getMixedStrides());
    return success();
  }
};
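// For illustration (the IR below is made up, not from a test): the folder
// above rewrites
//
//   %a = memref.subview %src[2, 4] [8, 8] [1, 1]
//       : memref<16x16xf32> to memref<8x8xf32, strided<[16, 1], offset: 36>>
//   %b = memref.subview %a[1, 1] [4, 4] [1, 1]
//       : memref<8x8xf32, strided<[16, 1], offset: 36>>
//         to memref<4x4xf32, strided<[16, 1], offset: 53>>
//
// into a single subview of the original source:
//
//   %b = memref.subview %src[3, 5] [4, 4] [1, 1]
//       : memref<16x16xf32> to memref<4x4xf32, strided<[16, 1], offset: 53>>
//
// with combined offsets materialized via affine.apply when they are dynamic.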
/// Folds subviews on the source and destination memrefs of an
/// nvgpu.device_async_copy into the copy itself.
class NVGPUAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};

} // namespace
template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(llvm::is_one_of<XferOp, vector::TransferReadOp,
                                vector::TransferWriteOp>::value,
                "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}
static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}
static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}
static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}
template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
      loadOp.getIndices(), sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  // memref.load guarantees that indices start inbounds, while the vector
  // operations do not; this determines whether the linearization of the
  // indices below may be marked disjoint.
  resolveSourceIndicesExpandShape(loadOp.getLoc(), rewriter, expandShapeOp,
                                  loadOp.getIndices(), sourceIndices,
                                  isa<memref::LoadOp>(loadOp.getOperation()));
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        // The folded source must have at least as many dims as the vector,
        // otherwise no permutation map on the source can produce the result
        // vector.
        const int64_t vectorRank = op.getVectorType().getRank();
        const int64_t sourceRank = sourceIndices.size();
        if (sourceRank < vectorRank)
          return;

        // Only fold if the original permutation map reads the minor
        // (innermost) dimension of a reassociation group, i.e. the vector
        // data stays contiguous after collapsing the expanded dims back onto
        // the source.
        bool foundExpr = false;
        for (auto reassocationIndices :
             llvm::enumerate(expandShapeOp.getReassociationIndices())) {
          auto reassociation = reassocationIndices.value();
          AffineExpr minorExpr = getAffineDimExpr(
              reassociation[reassociation.size() - 1], rewriter.getContext());
          foundExpr |=
              llvm::is_contained(op.getPermutationMap().getResults(), minorExpr);
        }
        if (!foundExpr)
          return;

        // Read the minor dims of the source with a minor identity map.
        AffineMap newMap = AffineMap::getMinorIdentityMap(
            sourceRank, vectorRank, rewriter.getContext());
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), expandShapeOp.getViewSource(),
            sourceIndices, newMap, op.getPadding(), op.getMask(),
            op.getInBoundsAttr());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  resolveSourceIndicesCollapseShape(loadOp.getLoc(), rewriter, collapseShapeOp,
                                    loadOp.getIndices(), sourceIndices);
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
      storeOp.getIndices(), sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  // memref.store guarantees that indices start inbounds, while the vector
  // operations do not; this determines whether the linearization of the
  // indices below may be marked disjoint.
  resolveSourceIndicesExpandShape(storeOp.getLoc(), rewriter, expandShapeOp,
                                  storeOp.getIndices(), sourceIndices,
                                  isa<memref::StoreOp>(storeOp.getOperation()));
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  resolveSourceIndicesCollapseShape(storeOp.getLoc(), rewriter, collapseShapeOp,
                                    storeOp.getIndices(), sourceIndices);
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {

  LLVM_DEBUG(DBGS() << "copyOp       : " << copyOp << "\n");

  auto srcSubViewOp = copyOp.getSrc().getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp = copyOp.getDst().getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, resolve the indices on the original memref.
  SmallVector<Value> foldedSrcIndices(copyOp.getSrcIndices().begin(),
                                      copyOp.getSrcIndices().end());
  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        copyOp.getSrcIndices(), foldedSrcIndices);
  }

  // Likewise for the destination.
  SmallVector<Value> foldedDstIndices(copyOp.getDstIndices().begin(),
                                      copyOp.getDstIndices().end());
  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        copyOp.getDstIndices(), foldedDstIndices);
  }

  // Replace the copy with a new copy on the subview sources.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}
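// Illustrative example (made-up IR): with a subview-produced source
//
//   %s = memref.subview %gmem[%o] [64] [1]
//       : memref<1024xf32> to memref<64xf32, strided<[1], offset: ?>>
//   %t = nvgpu.device_async_copy %s[%i], %smem[%j], 4
//
// the pattern recreates the copy on %gmem directly, with the source index
// resolved to %i + %o via an affine.apply.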
void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewOpFolder<memref::LoadOp>,
               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
               StoreOpOfSubViewOpFolder<memref::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
               StoreOpOfSubViewOpFolder<vector::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
               StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
               StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
               SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}
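// A minimal usage sketch for clients embedding these folds in their own pass
// (driver boilerplate assumed, not part of this file):
//
//   RewritePatternSet patterns(ctx);
//   memref::populateFoldMemRefAliasOpPatterns(patterns);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     return signalPassFailure();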
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

} // namespace
void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
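// From the command line, the same folding runs as a standalone pass
// (assuming the standard registration generated from Passes.td):
//
//   mlir-opt --fold-memref-alias-ops input.mlir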