#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
/// Given the 'indices' of a load/store operation where the memref is a result
/// of an expand_shape op, returns the indices with respect to the source
/// memref of the expand_shape op.
static LogicalResult resolveSourceIndicesExpandShape(
    Location loc, PatternRewriter &rewriter,
    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();

  // Traverse all reassociation groups to determine the appropriate indices
  // corresponding to each one of them post op folding.
  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
    assert(!group.empty() && "association indices groups cannot be empty");
    int64_t groupSize = group.size();
    if (groupSize == 1) {
      sourceIndices.push_back(indices[group[0]]);
      continue;
    }
    SmallVector<OpFoldResult> groupBasis =
        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
    SmallVector<Value> groupIndices =
        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
    sourceIndices.push_back(collapsedIndex);
  }
  return success();
}
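
// Illustrative sketch (hypothetical IR, not taken from the upstream tests):
// given
//   %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8]
//       : memref<32xf32> into memref<4x8xf32>
//   %v = memref.load %0[%i, %j] : memref<4x8xf32>
// the helper above linearizes the group's indices, roughly as
//   %lin = affine.linearize_index disjoint [%i, %j] by (4, 8) : index
// so that the load can be rewritten to read %arg0[%lin] directly.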
/// Given the 'indices' of a load/store operation where the memref is a result
/// of a collapse_shape op, returns the indices with respect to the source
/// memref of the collapse_shape op.
static LogicalResult
resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                  memref::CollapseShapeOp collapseShapeOp,
                                  ValueRange indices,
                                  SmallVectorImpl<Value> &sourceIndices) {
  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
      loc, collapseShapeOp.getSrc());
  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
  for (auto [index, group] :
       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
    assert(!group.empty() && "association indices groups cannot be empty");
    int64_t groupSize = group.size();

    if (groupSize == 1) {
      sourceIndices.push_back(index);
      continue;
    }

    SmallVector<OpFoldResult> basis =
        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
        loc, index, basis, /*hasOuterBound=*/true);
    llvm::append_range(sourceIndices, delinearize.getResults());
  }
  if (collapseShapeOp.getReassociationIndices().empty()) {
    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
    int64_t srcRank =
        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
    for (int64_t i = 0; i < srcRank; i++) {
      sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
    }
  }
  return success();
}
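
// Illustrative sketch (hypothetical IR): for a load through a collapse_shape
//   %0 = memref.collapse_shape %arg0 [[0, 1]]
//       : memref<4x8xf32> into memref<32xf32>
//   %v = memref.load %0[%i] : memref<32xf32>
// the helper above delinearizes %i against the source sizes, roughly as
//   %i0, %i1 = affine.delinearize_index %i into (4, 8) : index, index
// so that the load can be rewritten to read %arg0[%i0, %i1] directly.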
/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}
static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}
static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}
static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};
/// Folds subview(subview(x)) into a single subview(x).
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();
    if (!subView.hasUnitStride()) {
      return failure();
    }
    if (!srcSubView.hasUnitStride()) {
      return failure();
    }
    // Resolve sizes according to dropped dims, then offsets according to the
    // source offsets and strides.
    SmallVector<OpFoldResult> resolvedSizes;
    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
        resolvedOffsets);
    // Replace the original subview with one rooted at the producer's source.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        srcSubView.getMixedStrides());
    return success();
  }
};
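
// Illustrative sketch (hypothetical IR, offsets chosen for the example): the
// pattern above rewrites
//   %a = memref.subview %src[2, 2] [16, 16] [1, 1] : memref<64x64xf32> to ...
//   %b = memref.subview %a[1, 3] [4, 4] [1, 1] : ...
// into a single memref.subview of %src with composed offsets [3, 5], the
// outer sizes [4, 4], and the producer's unit strides.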
/// Folds memref.subview producers into nvgpu.device_async_copy operands.
class NVGPUAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};
static SmallVector<Value>
calculateExpandedAccessIndices(AffineMap affineMap,
                               const SmallVector<Value> &indices, Location loc,
                               PatternRewriter &rewriter) {
  SmallVector<OpFoldResult> indicesOfr = getAsOpFoldResult(ValueRange(indices));
  SmallVector<Value> expandedIndices;
  for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
    expandedIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return expandedIndices;
}
template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(
      !llvm::is_one_of<vector::TransferReadOp, vector::TransferWriteOp>::value,
      "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}
static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}
template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return failure();

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, apply the affine map to obtain the actual access indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();
  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, apply the affine map to obtain the actual access indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  // memref.load and affine.load guarantee that indices start inbounds, while
  // the vector operations don't; this determines whether the linearization is
  // disjoint.
  if (failed(resolveSourceIndicesExpandShape(
          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
          isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();
  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, apply the affine map to obtain the actual access indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return failure();

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, apply the affine map to obtain the actual access indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();
  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, apply the affine map to obtain the actual access indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  // memref.store and affine.store guarantee that indices start inbounds, while
  // the vector operations don't; this determines whether the linearization is
  // disjoint.
  if (failed(resolveSourceIndicesExpandShape(
          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
          isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();
  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, apply the affine map to obtain the actual access indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {
  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, resolve the copy's source indices into the
  // subview's source.
  SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
                                copyOp.getSrcIndices().end());
  SmallVector<Value> foldedSrcIndices(srcindices);
  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        srcindices, foldedSrcIndices);
  }

  // Likewise for the destination indices.
  SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
                                copyOp.getDstIndices().end());
  SmallVector<Value> foldedDstIndices(dstindices);
  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        dstindices, foldedDstIndices);
  }

  // Replace the copy with one that reads from / writes to the sources of the
  // folded subviews.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}
/// Appends patterns for folding memref aliasing ops into consumer load/store
/// ops into `patterns`.
void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
               LoadOpOfSubViewOpFolder<memref::LoadOp>,
               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
               StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
               StoreOpOfSubViewOpFolder<memref::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
               StoreOpOfSubViewOpFolder<vector::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
               StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
               LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
               LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
               SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}
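
// Hedged usage sketch (not part of this file): a client pass or test driver
// could apply these patterns directly; `funcOp` below is a placeholder for
// whatever operation the caller wants to rewrite.
//
//   RewritePatternSet patterns(funcOp.getContext());
//   memref::populateFoldMemRefAliasOpPatterns(patterns);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     return failure();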
struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}