#include "mlir/Dialect/MemRef/Transforms/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;
/// Helpers to access the memref operand for each op type.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}
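// Note: these overloads give the templated patterns below a uniform way to
// reach the aliased memref operand. For any supported op, a pattern can write
//   getMemRefOperand(op).getDefiningOp<memref::SubViewOp>()
// to test whether the accessed memref is produced by a subview.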
namespace {

/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};
/// Merges expand_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};
/// Merges collapse_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};
/// Merges subview operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};
/// Merges expand_shape operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};
/// Merges collapse_shape operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};
/// Folds subview(subview(x)) into a single subview(x).
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return rewriter.notifyMatchFailure(subView, "not a subview of a subview");

    // Both subviews must have unit strides so offsets compose by addition.
    if (!subView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(subView, "requires unit strides");
    }
    if (!srcSubView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
        resolvedOffsets);

    // Replace the original op with a subview directly off the outer source.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        srcSubView.getMixedStrides());
    return success();
  }
};
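// For example (illustrative IR, values chosen here for exposition), the
// pattern rewrites
//
//   %0 = memref.subview %arg0[2, 2] [6, 6] [1, 1]
//          : memref<8x8xf32> to memref<6x6xf32, strided<[8, 1], offset: 18>>
//   %1 = memref.subview %0[1, 1] [4, 4] [1, 1]
//          : memref<6x6xf32, strided<[8, 1], offset: 18>>
//            to memref<4x4xf32, strided<[8, 1], offset: 27>>
//
// into a single subview of the outermost source with composed offsets:
//
//   %1 = memref.subview %arg0[3, 3] [4, 4] [1, 1]
//          : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1], offset: 27>>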
/// Folds nvgpu.device_async_copy when the source and/or destination is a
/// subview.
class NVGPUAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};

} // namespace
template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(llvm::is_one_of<XferOp, vector::TransferReadOp,
                                vector::TransferWriteOp>::value,
                "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}
template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
      loadOp.getIndices(), sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  // memref.load guarantees that the indices start inbounds, while the vector
  // operations don't. This affects whether the linearization is `disjoint`.
  resolveSourceIndicesExpandShape(loadOp.getLoc(), rewriter, expandShapeOp,
                                  loadOp.getIndices(), sourceIndices,
                                  isa<memref::LoadOp>(loadOp.getOperation()));

  return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp)
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
        return success();
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
        return success();
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
        return success();
      })
      .Case([&](vector::TransferReadOp op) {
        // Only minor-identity permutation maps are supported here.
        if (!op.getPermutationMap().isMinorIdentity())
          return rewriter.notifyMatchFailure(op, "unsupported permutation map");
        // The linearized source must still have enough trailing dimensions to
        // read the vector from.
        const int64_t sourceRank = sourceIndices.size();
        const int64_t vectorRank = op.getVectorType().getRank();
        if (sourceRank < vectorRank)
          return rewriter.notifyMatchFailure(op, "source rank too small");
        auto minorIdMap = AffineMap::getMinorIdentityMap(sourceRank, vectorRank,
                                                         op.getContext());
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), expandShapeOp.getViewSource(),
            sourceIndices, minorIdMap, op.getPadding(), op.getMask(),
            op.getInBoundsAttr());
        return success();
      })
      .DefaultUnreachable("unexpected operation");
}
template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  resolveSourceIndicesCollapseShape(loadOp.getLoc(), rewriter, collapseShapeOp,
                                    loadOp.getIndices(), sourceIndices);
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
      storeOp.getIndices(), sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  // memref.store guarantees that the indices start inbounds, while the vector
  // operations don't. This affects whether the linearization is `disjoint`.
  resolveSourceIndicesExpandShape(storeOp.getLoc(), rewriter, expandShapeOp,
                                  storeOp.getIndices(), sourceIndices,
                                  isa<memref::StoreOp>(storeOp.getOperation()));
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  resolveSourceIndicesCollapseShape(storeOp.getLoc(), rewriter,
                                    collapseShapeOp, storeOp.getIndices(),
                                    sourceIndices);
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {

  LLVM_DEBUG(DBGS() << "copyOp       : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, resolve the indices into the original memref.
  SmallVector<Value> foldedSrcIndices(copyOp.getSrcIndices().begin(),
                                      copyOp.getSrcIndices().end());
  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        copyOp.getSrcIndices(), foldedSrcIndices);
  }

  // If the destination is a subview, resolve the indices likewise.
  SmallVector<Value> foldedDstIndices(copyOp.getDstIndices().begin(),
                                      copyOp.getDstIndices().end());
  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        copyOp.getDstIndices(), foldedDstIndices);
  }

  // Replace the copy op with one that addresses the subview sources directly.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}
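// For example (illustrative IR): a copy whose source is a subview,
//
//   %0 = memref.subview %src[%o] [64] [1]
//          : memref<128xf32> to memref<64xf32, strided<[1], offset: ?>>
//   %t = nvgpu.device_async_copy %0[%i], %dst[%j], 4
//          : memref<64xf32, strided<[1], offset: ?>> to memref<128xf32, 3>
//
// is rewritten to copy straight from %src at the folded index %o + %i,
// dropping the intermediate subview.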
void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewOpFolder<memref::LoadOp>,
               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
               StoreOpOfSubViewOpFolder<memref::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
               StoreOpOfSubViewOpFolder<vector::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
               StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
               StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
               SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}
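// Out-of-tree users can apply the same patterns with the greedy driver; the
// pass below does exactly this. A minimal sketch:
//
//   RewritePatternSet patterns(ctx);
//   memref::populateFoldMemRefAliasOpPatterns(patterns);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     signalPassFailure();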
namespace {

struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}