#include "mlir/Dialect/MemRef/Transforms/Passes.h"

// Dialect and utility headers for the ops and helpers used below
// (reconstructed from the symbols referenced in this file; the upstream
// include list may differ slightly).
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;
/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}

// Overloads for the vector load/store and vector transfer ops follow the
// same shape, returning the op's base memref operand.
namespace {
/// Folds memref.subview into the load/read-like op that consumes it.
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Folds memref.expand_shape into the load/read-like op that consumes it.
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Folds memref.collapse_shape into the load/read-like op that consumes it.
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Folds memref.subview into the store/write-like op that consumes it.
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Folds memref.expand_shape into the store/write-like op that consumes it.
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Folds memref.collapse_shape into the store/write-like op that consumes it.
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};
/// Folds subview(subview(x)) into a single subview(x).
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();

    // TODO: relax the unit stride requirement.
    if (!subView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(subView, "requires unit strides");
    }
    if (!srcSubView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
    }

    // Resolve sizes, then offsets, according to the dims dropped by the
    // source subview.
    SmallVector<OpFoldResult> resolvedSizes;
    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);

    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
        resolvedOffsets);

    // Replace the original op with a subview directly on the root memref.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        srcSubView.getMixedStrides());
    return success();
  }
};
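// Example of the subview-of-subview fold above (illustrative): a subview at
// offset %a of a subview at offset %b of %m becomes a single subview of %m at
// offset %a + %b per non-dropped dimension, with the addition produced by an
// affine.apply. The sizes come from the outer subview, and both subviews are
// required to have unit strides.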
/// Folds memref.subview producers into the source and destination operands of
/// an nvgpu.device_async_copy op.
class NVGPUAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};
} // namespace

/// For an affine load/store, materialize each result of the access map as an
/// explicit index value so it can be rewritten against the folded memref.
static SmallVector<Value>
calculateExpandedAccessIndices(AffineMap affineMap,
                               const SmallVector<Value> &indices, Location loc,
                               PatternRewriter &rewriter) {
  SmallVector<OpFoldResult> indicesOfr(llvm::to_vector(
      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; })));
  SmallVector<Value> expandedIndices;
  for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
    expandedIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return expandedIndices;
}
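// For instance (illustrative), an affine.load with access map
// (d0, d1) -> (d0 + 4, d1) and operands (%i, %j) yields the two explicit
// indices %i + 4 and %j, each produced by makeComposedFoldedAffineApply on
// the corresponding single-result submap.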
/// Shared preconditions for folding a memref.subview into a vector transfer.
template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(
      llvm::is_one_of<XferOp, vector::TransferReadOp,
                      vector::TransferWriteOp>::value,
      "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}
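// Rationale for the bail-outs above: a transfer with an out-of-bounds
// dimension pads/masks accesses relative to the subview's shape, and folding
// the subview away would change the shape those semantics are defined
// against. Non-unit subview strides are also rejected because the vector
// transfer addresses a per-dimension contiguous block, which no longer lines
// up once the strided view is folded into the base memref.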
template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, apply the access map first to obtain the actual indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();
  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, apply the access map first to obtain the actual indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  // memref.load and affine.load guarantee in-bounds accesses.
  if (failed(memref::resolveSourceIndicesExpandShape(
          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
          isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();
  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, apply the access map first to obtain the actual indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(memref::resolveSourceIndicesCollapseShape(
          loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, apply the access map first to obtain the actual indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();
  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, apply the access map first to obtain the actual indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  // memref.store and affine.store guarantee in-bounds accesses.
  if (failed(memref::resolveSourceIndicesExpandShape(
          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
          isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();
  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, apply the access map first to obtain the actual indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(memref::resolveSourceIndicesCollapseShape(
          storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {

  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, resolve its indices into the root memref.
  SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
                                copyOp.getSrcIndices().end());
  SmallVector<Value> foldedSrcIndices(srcindices);
  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        srcindices, foldedSrcIndices);
  }

  // Same for the destination.
  SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
                                copyOp.getDstIndices().end());
  SmallVector<Value> foldedDstIndices(dstindices);
  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        dstindices, foldedDstIndices);
  }

  // Replace the copy with one that addresses the root memrefs directly.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}
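// Note that the async-copy fold handles each side independently: the source
// and the destination are each rewritten only if they are produced by a
// memref.subview; an operand that is not a subview is passed through
// unchanged.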
void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
               LoadOpOfSubViewOpFolder<memref::LoadOp>,
               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
               StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
               StoreOpOfSubViewOpFolder<memref::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
               StoreOpOfSubViewOpFolder<vector::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
               StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
               LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
               LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
               SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}
namespace {

/// Pass to fold memref aliasing ops into their consumer load/store ops.
struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
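// The pass is exposed under the -fold-memref-alias-ops flag, so e.g.
// `mlir-opt --fold-memref-alias-ops input.mlir` applies all of the patterns
// above greedily; populateFoldMemRefAliasOpPatterns makes the same pattern
// set available to other pipelines.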