#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

#define GEN_PASS_DEF_FOLDMEMREFALIASOPS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
// Given the 'indices' of a load/store op whose memref is the result of an
// expand_shape op, computes the corresponding indices on the expand_shape
// source.
static LogicalResult resolveSourceIndicesExpandShape(
    Location loc, PatternRewriter &rewriter,
    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
    SmallVectorImpl<Value> &sourceIndices) {
  // The suffix-product computation below requires a static result shape.
  if (!expandShapeOp.getResultType().hasStaticShape())
    return failure();
  // ...
    assert(!groups.empty() && "association indices groups cannot be empty");
    int64_t groupSize = groups.size();
    for (int64_t i = 0; i < groupSize; ++i)
      sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]);
    // ...
    for (int64_t i = 0; i < groupSize; i++)
      dynamicIndices[i] = indices[groups[i]];
    // ... (`ofr` is the folded affine.apply of the linearized group index)
    sourceIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
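// Worked example for the resolution above (a sketch; the shapes are purely
// illustrative): expanding memref<24xf32> into memref<2x3x4xf32> yields one
// reassociation group [0, 1, 2] with sizes [2, 3, 4] and suffix product
// [12, 4, 1], so result indices (%i, %j, %k) linearize to the source index
//   affine_map<(d0, d1, d2) -> (d0 * 12 + d1 * 4 + d2)>(%i, %j, %k)
// materialized as a single folded affine.apply.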
// Given the 'indices' of a load/store op whose memref is the result of a
// collapse_shape op, computes the corresponding indices on the collapse_shape
// source.
static LogicalResult resolveSourceIndicesCollapseShape(
    Location loc, PatternRewriter &rewriter,
    memref::CollapseShapeOp collapseShapeOp, ValueRange indices,
    SmallVectorImpl<Value> &sourceIndices) {
  // ...
    assert(!groups.empty() && "association indices groups cannot be empty");
    dynamicIndices.push_back(indices[cnt++]);
    int64_t groupSize = groups.size();
    // All but the most major source dim of each group must be static.
    for (int64_t i = 1; i < groupSize; ++i) {
      sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
      if (sizes[i] == ShapedType::kDynamic)
        return failure();
    }
    // Delinearize the collapsed index with one expression per source dim.
    for (int64_t i = 0; i < groupSize; i++) {
      // ... (`ofr` is the folded affine.apply of `delinearizingExprs[i]`)
      sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
    }
    dynamicIndices.clear();
  // Rank-0 collapse: every source index folds through `zeroAffineMap`
  // (rewriter.getConstantAffineMap(0)).
  if (collapseShapeOp.getReassociationIndices().empty()) {
    int64_t srcRank =
        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
    for (int64_t i = 0; i < srcRank; i++) {
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, zeroAffineMap, dynamicIndices);
      sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
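// Worked example for the resolution above (a sketch; the shapes are purely
// illustrative): collapsing memref<2x3x4xf32> into memref<24xf32> keeps the
// non-major sizes [_, 3, 4], giving the suffix product [12, 4, 1]; a single
// collapsed index %i therefore delinearizes to the source indices
//   (%i floordiv 12, (%i mod 12) floordiv 4, %i mod 4).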
/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}
static Value getMemRefOperand(vector::TransferReadOp op) { return op.getSource(); }
static Value getMemRefOperand(nvgpu::LdMatrixOp op) { return op.getSrcMemref(); }
// ... (analogous overloads for the remaining vector load/store ops)
static Value getMemRefOperand(vector::TransferWriteOp op) { return op.getSource(); }
static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) { return op.getSrcMemref(); }
static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) { return op.getDstMemref(); }
/// Patterns that fold a memref.subview / expand_shape / collapse_shape
/// producer into the load or store op consuming it; one class template per
/// (alias op, access direction) pair, each an OpRewritePattern over the
/// concrete load/store op type OpTy.
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> { /*...*/ };
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> { /*...*/ };
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> { /*...*/ };
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> { /*...*/ };
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> { /*...*/ };
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final
    : public OpRewritePattern<OpTy> { /*...*/ };
    // Fold subview(subview(x)) into a single subview of x.
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    // Both subviews must have unit strides.
    if (!subView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(subView, "requires unit strides");
    }
    if (!srcSubView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
    }
    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
        resolvedOffsets);
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        srcSubView.getMixedStrides());
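// Example of the fold above (a sketch; types and offsets are illustrative
// only):
//   %a = memref.subview %src[2, 4] [8, 8] [1, 1]
//          : memref<16x16xf32> to memref<8x8xf32, strided<[16, 1], offset: 36>>
//   %b = memref.subview %a[1, 1] [4, 4] [1, 1]
// is rewritten into a single subview of %src with composed offsets:
//   %b = memref.subview %src[3, 5] [4, 4] [1, 1]
//          : memref<16x16xf32> to memref<4x4xf32, strided<[16, 1], offset: 53>>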
class NvgpuAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
  // ...
  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};

static SmallVector<Value>
calculateExpandedAccessIndices(AffineMap affineMap,
                               const SmallVector<Value> &indices, Location loc,
                               PatternRewriter &rewriter) {
  // ... (`indicesOfr` and `expandedIndices` are derived from `indices`)
  for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
    expandedIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return expandedIndices;
}
template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(llvm::is_one_of<XferOp, vector::TransferReadOp,
                                vector::TransferWriteOp>::value,
                "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}
static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}
static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}
static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}
template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, expand the indices through the op's affine map first.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();
  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, expand the indices through the op's affine map first.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesExpandShape(
          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
        rewriter.replaceOpWithNewOp<decltype(op)>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices);
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();
  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, expand the indices through the op's affine map first.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
        rewriter.replaceOpWithNewOp<decltype(op)>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices);
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, expand the indices through the op's affine map first.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();
  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, expand the indices through the op's affine map first.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesExpandShape(
          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
        rewriter.replaceOpWithNewOp<decltype(op)>(
            storeOp, storeOp.getValue(), expandShapeOp.getViewSource(),
            sourceIndices);
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();
  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, expand the indices through the op's affine map first.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
        rewriter.replaceOpWithNewOp<decltype(op)>(
            storeOp, storeOp.getValue(), collapseShapeOp.getViewSource(),
            sourceIndices);
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
LogicalResult NvgpuAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {
  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();
  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is produced by a subview, resolve the source indices.
  SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
                                copyOp.getSrcIndices().end());
  SmallVector<Value> foldedSrcIndices(srcindices);
  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        srcindices, foldedSrcIndices);
  }

  // Likewise for the destination side.
  SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
                                copyOp.getDstIndices().end());
  SmallVector<Value> foldedDstIndices(dstindices);
  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        dstindices, foldedDstIndices);
  }

  // Replace the copy with one that reads/writes the subview sources directly.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());
  return success();
}
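// Sketch of the fold above (the textual operand order of the copy op is
// elided; memrefs and offsets are illustrative only): a device_async_copy
// whose source reads through
//   %sv = memref.subview %gmem[%o, 0] [1, 8] [1, 1] : memref<128x32xf32> to ...
// at indices [%c0, %c0] is rewritten to read %gmem directly at the resolved
// indices [%o, %c0]; the destination side is resolved the same way when it is
// produced by a subview.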
void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
               LoadOpOfSubViewOpFolder<memref::LoadOp>,
               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
               StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
               StoreOpOfSubViewOpFolder<memref::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
               StoreOpOfSubViewOpFolder<vector::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
               StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
               LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
               StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
               LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
               StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
               SubViewOfSubViewFolder, NvgpuAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}
struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<Pass> memref::createFoldMemRefAliasOpsPass() {
  return std::make_unique<FoldMemRefAliasOpsPass>();
}
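// Typical client usage (a sketch, not part of this file; `pm`, `ctx` and `op`
// are assumed to exist in the caller):
//   pm.addPass(memref::createFoldMemRefAliasOpsPass());
// or, to apply the rewrites without running the pass:
//   RewritePatternSet patterns(ctx);
//   memref::populateFoldMemRefAliasOpPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));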