#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
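
//===----------------------------------------------------------------------===//
// The patterns below fold memref aliasing ops (memref.subview,
// memref.expand_shape, memref.collapse_shape) into the load/store-like ops
// that consume them, so the access addresses the original memref directly.
// Illustrative example of the overall effect (simplified IR):
//
//   %sv = memref.subview %A[%o0, %o1] [16, 16] [1, 1]
//       : memref<128x128xf32> to memref<16x16xf32, strided<[128, 1], offset: ?>>
//   %v  = memref.load %sv[%i, %j]
//
// is rewritten so the load reads %A at indices (%i + %o0, %j + %o1), with the
// offsets materialized through affine.apply ops.
//===----------------------------------------------------------------------===//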
        reassocs.take_back(n),

  // The trailing `n` dimensions of the subview must keep a unit stride.
  return llvm::all_of(subview.getStaticStrides().take_back(n),
                      [](int64_t s) { return s == 1; });
/// Helpers to access the memref operand of each load/store-like op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}
static Value getMemRefOperand(nvgpu::LdMatrixOp op) { return op.getSrcMemref(); }
static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}
static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}
/// Merges a producing memref.subview into load/transfer_read-like ops.
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges a producing memref.expand_shape into load/transfer_read-like ops.
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges a producing memref.collapse_shape into load-like ops.
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges a producing memref.subview into store/transfer_write-like ops.
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges a producing memref.expand_shape into store-like ops.
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges a producing memref.collapse_shape into store-like ops.
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};
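
// Folds a subview whose source is itself a subview into a single subview of
// the outer source. Illustrative example (simplified IR):
//
//   %0 = memref.subview %A[%a] [64] [1] : memref<128xf32> to memref<64xf32, ...>
//   %1 = memref.subview %0[%b] [16] [1] : memref<64xf32, ...> to memref<16xf32, ...>
//
// becomes a single memref.subview of %A with offset %a + %b and size 16.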
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();

    // Merge the offsets, sizes and strides of the two subviews.
    SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
    if (failed(affine::mergeOffsetsSizesAndStrides(
            rewriter, subView.getLoc(), srcSubView, subView,
            srcSubView.getDroppedDims(), newOffsets, newSizes, newStrides)))
      return failure();

    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(), newOffsets,
        newSizes, newStrides);
    return success();
  }
};
/// Folds producing memref.subview ops into the source/destination memrefs and
/// indices of an nvgpu.device_async_copy.
class NVGPUAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};
/// Folds a producing memref.subview into memref::IndexedAccessOpInterface ops.
struct AccessOpOfSubViewOpFolder final
    : public OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Folds a producing memref.expand_shape into memref::IndexedAccessOpInterface
/// ops.
struct AccessOpOfExpandShapeOpFolder final
    : public OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Folds a producing memref.collapse_shape into
/// memref::IndexedAccessOpInterface ops.
struct AccessOpOfCollapseShapeOpFolder final
    : public OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
                                PatternRewriter &rewriter) const override;
};
/// Folds memref.subview producers into memref::IndexedMemCopyOpInterface ops.
struct IndexedMemCopyOpOfSubViewOpFolder final
    : public OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Folds memref.expand_shape producers into memref::IndexedMemCopyOpInterface
/// ops.
struct IndexedMemCopyOpOfExpandShapeOpFolder final
    : public OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Folds memref.collapse_shape producers into
/// memref::IndexedMemCopyOpInterface ops.
struct IndexedMemCopyOpOfCollapseShapeOpFolder final
    : public OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
                                PatternRewriter &rewriter) const override;
};
template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(llvm::is_one_of<XferOp, vector::TransferReadOp,
                                vector::TransferWriteOp>::value,
                "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}

// Per-op preconditions: only vector transfer ops impose extra constraints.
static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}
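
// Illustrative example (simplified IR) of the subview fold performed below:
//
//   %sv = memref.subview %A[%off] [32] [1] : memref<1024xf32> to memref<32xf32, ...>
//   %v  = memref.load %sv[%i]
//
// is rewritten to load from %A at index %i + %off; the same index rewriting is
// applied to vector.load, vector.masked_load, vector.transfer_read,
// gpu.subgroup_mma_load_matrix and nvgpu.ldmatrix.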
template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return failure();

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  // Remap the load indices onto the subview's source memref.
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
      loadOp.getIndices(), sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp = getMemRefOperand(loadOp)
                           .template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  // For vector.transfer_read, rebuild the permutation map in terms of the
  // (collapsed) source dims of the expand_shape.
  SmallVector<AffineExpr> transferReadNewResults;
  if (auto transferOp =
          dyn_cast<vector::TransferReadOp>(loadOp.getOperation())) {
    const int64_t vectorRank = transferOp.getVectorType().getRank();
    const int64_t sourceRank =
        cast<MemRefType>(expandShapeOp.getViewSource().getType()).getRank();
    if (sourceRank < vectorRank)
      return rewriter.notifyMatchFailure(
          transferOp, "source rank is smaller than the vector rank");

    // A source dim participates in the folded read if the permutation map
    // reads the innermost expanded dim of its reassociation group.
    bool foundExpr = false;
    for (auto reassocationIndices :
         llvm::enumerate(expandShapeOp.getReassociationIndices())) {
      auto reassociation = reassocationIndices.value();
      AffineExpr dimExpr = getAffineDimExpr(
          reassociation[reassociation.size() - 1], rewriter.getContext());
      if (llvm::is_contained(transferOp.getPermutationMap().getResults(),
                             dimExpr)) {
        foundExpr = true;
        transferReadNewResults.push_back(getAffineDimExpr(
            reassocationIndices.index(), rewriter.getContext()));
      }
    }
    if (!foundExpr)
      return failure();
  }

  SmallVector<Value> sourceIndices;
  resolveSourceIndicesExpandShape(loadOp.getLoc(), rewriter, expandShapeOp,
                                  loadOp.getIndices(), sourceIndices,
                                  isa<memref::LoadOp>(loadOp.getOperation()));
  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        // Rebuild the transfer with the remapped permutation map.
        const int64_t sourceRank = sourceIndices.size();
        auto newMap = AffineMap::get(sourceRank, 0, transferReadNewResults,
                                     rewriter.getContext());
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), expandShapeOp.getViewSource(),
            sourceIndices, newMap, op.getPadding(), op.getMask(),
            op.getInBoundsAttr());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  // Remap the collapsed load indices onto the original source memref.
  SmallVector<Value> sourceIndices;
  resolveSourceIndicesCollapseShape(loadOp.getLoc(), rewriter, collapseShapeOp,
                                    loadOp.getIndices(), sourceIndices);
  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return failure();

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  // Remap the store indices onto the subview's source memref.
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
      storeOp.getIndices(), sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp = getMemRefOperand(storeOp)
                           .template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  resolveSourceIndicesExpandShape(storeOp.getLoc(), rewriter, expandShapeOp,
                                  storeOp.getIndices(), sourceIndices,
                                  isa<memref::StoreOp>(storeOp.getOperation()));
  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  resolveSourceIndicesCollapseShape(storeOp.getLoc(), rewriter, collapseShapeOp,
                                    storeOp.getIndices(), sourceIndices);
  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
LogicalResult
AccessOpOfSubViewOpFolder::matchAndRewrite(memref::IndexedAccessOpInterface op,
                                           PatternRewriter &rewriter) const {
  auto subview = op.getAccessedMemref().getDefiningOp<memref::SubViewOp>();
  if (!subview)
    return failure();

  SmallVector<int64_t> accessedShape = op.getAccessedShape();
  int64_t accessedDims = accessedShape.size();
  // The dimensions the op actually accesses must keep a unit stride.
  if (!llvm::all_of(subview.getStaticStrides().take_back(accessedDims),
                    [](int64_t s) { return s == 1; }))
    return rewriter.notifyMatchFailure(
        op, "non-unit stride on accessed dimensions");

  llvm::SmallBitVector droppedDims = subview.getDroppedDims();
  int64_t sourceRank = subview.getSourceType().getRank();
  // Folding must not reintroduce rank-reduced (dropped) dimensions inside the
  // trailing dimensions the op accesses.
  int64_t secondAccessedDim = sourceRank - (accessedDims - 1);
  if (secondAccessedDim < sourceRank) {
    for (int64_t d : llvm::seq(secondAccessedDim, sourceRank)) {
      if (droppedDims.test(d))
        return rewriter.notifyMatchFailure(
            op, "reintroducing dropped dimension " + Twine(d) +
                    " would break access op semantics");
    }
  }

  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, op.getLoc(), subview.getMixedOffsets(),
      subview.getMixedStrides(), droppedDims, op.getIndices(), sourceIndices);

  std::optional<SmallVector<Value>> newValues =
      op.updateMemrefAndIndices(rewriter, subview.getSource(), sourceIndices);
  if (!newValues)
    return failure();
  rewriter.replaceOp(op, *newValues);
  return success();
}
LogicalResult AccessOpOfExpandShapeOpFolder::matchAndRewrite(
    memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const {
  auto expand = op.getAccessedMemref().getDefiningOp<memref::ExpandShapeOp>();
  if (!expand)
    return failure();

  // Only the trailing accessed dimensions constrain the fold; the leading
  // accessed dimension may be reshaped freely.
  SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
  ArrayRef<int64_t> accessedShape = rawAccessedShape;
  if (!accessedShape.empty())
    accessedShape = accessedShape.drop_front();

  SmallVector<ReassociationIndices, 4> reassocs =
      expand.getReassociationIndices();
  if (!llvm::all_of(
          ArrayRef<ReassociationIndices>(reassocs).take_back(
              accessedShape.size()),
          [](const ReassociationIndices &r) { return r.size() == 1; }))
    return rewriter.notifyMatchFailure(
        op,
        "expand_shape folding would merge semantically important dimensions");

  SmallVector<Value> sourceIndices;
  resolveSourceIndicesExpandShape(op.getLoc(), rewriter, expand,
                                  op.getIndices(), sourceIndices,
                                  op.hasInboundsIndices());

  std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
      rewriter, expand.getViewSource(), sourceIndices);
  if (!newValues)
    return failure();
  rewriter.replaceOp(op, *newValues);
  return success();
}
LogicalResult AccessOpOfCollapseShapeOpFolder::matchAndRewrite(
    memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const {
  auto collapse =
      op.getAccessedMemref().getDefiningOp<memref::CollapseShapeOp>();
  if (!collapse)
    return failure();

  // Only the trailing accessed dimensions constrain the fold.
  SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
  ArrayRef<int64_t> accessedShape = rawAccessedShape;
  if (!accessedShape.empty())
    accessedShape = accessedShape.drop_front();

  SmallVector<ReassociationIndices, 4> reassocs =
      collapse.getReassociationIndices();
  if (!llvm::all_of(
          ArrayRef<ReassociationIndices>(reassocs).take_back(
              accessedShape.size()),
          [](const ReassociationIndices &r) { return r.size() == 1; }))
    return rewriter.notifyMatchFailure(op,
                                       "collapse_shape folding would merge "
                                       "semantically important dimensions");

  SmallVector<Value> sourceIndices;
  resolveSourceIndicesCollapseShape(op.getLoc(), rewriter, collapse,
                                    op.getIndices(), sourceIndices);

  std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
      rewriter, collapse.getViewSource(), sourceIndices);
  if (!newValues)
    return failure();
  rewriter.replaceOp(op, *newValues);
  return success();
}
LogicalResult IndexedMemCopyOpOfSubViewOpFolder::matchAndRewrite(
    memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
  auto srcSubview = op.getSrc().getDefiningOp<memref::SubViewOp>();
  auto dstSubview = op.getDst().getDefiningOp<memref::SubViewOp>();
  if (!srcSubview && !dstSubview)
    return rewriter.notifyMatchFailure(
        op, "no subviews found on indexed copy inputs");

  Value newSrc = op.getSrc();
  SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
  Value newDst = op.getDst();
  SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
  if (srcSubview) {
    newSrc = srcSubview.getSource();
    newSrcIndices.clear();
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, op.getLoc(), srcSubview.getMixedOffsets(),
        srcSubview.getMixedStrides(), srcSubview.getDroppedDims(),
        op.getSrcIndices(), newSrcIndices);
  }
  if (dstSubview) {
    newDst = dstSubview.getSource();
    newDstIndices.clear();
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, op.getLoc(), dstSubview.getMixedOffsets(),
        dstSubview.getMixedStrides(), dstSubview.getDroppedDims(),
        op.getDstIndices(), newDstIndices);
  }
  op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
                          newDstIndices);
  return success();
}
LogicalResult IndexedMemCopyOpOfExpandShapeOpFolder::matchAndRewrite(
    memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
  auto srcExpand = op.getSrc().getDefiningOp<memref::ExpandShapeOp>();
  auto dstExpand = op.getDst().getDefiningOp<memref::ExpandShapeOp>();
  if (!srcExpand && !dstExpand)
    return rewriter.notifyMatchFailure(
        op, "no expand_shapes found on indexed copy inputs");

  Value newSrc = op.getSrc();
  SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
  Value newDst = op.getDst();
  SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
  if (srcExpand) {
    newSrc = srcExpand.getViewSource();
    newSrcIndices.clear();
    resolveSourceIndicesExpandShape(op.getLoc(), rewriter, srcExpand,
                                    op.getSrcIndices(), newSrcIndices,
                                    /*startsInbounds=*/false);
  }
  if (dstExpand) {
    newDst = dstExpand.getViewSource();
    newDstIndices.clear();
    resolveSourceIndicesExpandShape(op.getLoc(), rewriter, dstExpand,
                                    op.getDstIndices(), newDstIndices,
                                    /*startsInbounds=*/false);
  }
  op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
                          newDstIndices);
  return success();
}
LogicalResult IndexedMemCopyOpOfCollapseShapeOpFolder::matchAndRewrite(
    memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
  auto srcCollapse = op.getSrc().getDefiningOp<memref::CollapseShapeOp>();
  auto dstCollapse = op.getDst().getDefiningOp<memref::CollapseShapeOp>();
  if (!srcCollapse && !dstCollapse)
    return rewriter.notifyMatchFailure(
        op, "no collapse_shapes found on indexed copy inputs");

  Value newSrc = op.getSrc();
  SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
  Value newDst = op.getDst();
  SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
  if (srcCollapse) {
    newSrc = srcCollapse.getViewSource();
    newSrcIndices.clear();
    resolveSourceIndicesCollapseShape(
        op.getLoc(), rewriter, srcCollapse, op.getSrcIndices(), newSrcIndices);
  }
  if (dstCollapse) {
    newDst = dstCollapse.getViewSource();
    newDstIndices.clear();
    resolveSourceIndicesCollapseShape(
        op.getLoc(), rewriter, dstCollapse, op.getDstIndices(), newDstIndices);
  }
  op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
                          newDstIndices);
  return success();
}
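
// Illustrative example (simplified IR) of the nvgpu.device_async_copy fold
// performed below:
//
//   %sv = memref.subview %dst[%o] [128] [1]
//       : memref<4096xf32, 3> to memref<128xf32, strided<[1], offset: ?>, 3>
//   %t  = nvgpu.device_async_copy %src[%i], %sv[%j], 4
//
// is rewritten to copy into %dst at index %j + %o, dropping the subview.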
LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {
  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp,
                                       "does not use subview ops for its "
                                       "source or destination");

  // Fold the subview producing the source, if present.
  SmallVector<Value> foldedSrcIndices(copyOp.getSrcIndices().begin(),
                                      copyOp.getSrcIndices().end());
  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        copyOp.getSrcIndices(), foldedSrcIndices);
  }

  // Fold the subview producing the destination, if present.
  SmallVector<Value> foldedDstIndices(copyOp.getDstIndices().begin(),
                                      copyOp.getDstIndices().end());
  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        copyOp.getDstIndices(), foldedDstIndices);
  }

  // Replace the copy with one that addresses the original memrefs directly.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}
void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<
      AccessOpOfSubViewOpFolder, AccessOpOfExpandShapeOpFolder,
      AccessOpOfCollapseShapeOpFolder, IndexedMemCopyOpOfSubViewOpFolder,
      IndexedMemCopyOpOfExpandShapeOpFolder,
      IndexedMemCopyOpOfCollapseShapeOpFolder,
      LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
      LoadOpOfSubViewOpFolder<vector::LoadOp>,
      LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
      LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
      LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
      StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
      StoreOpOfSubViewOpFolder<vector::StoreOp>,
      StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
      StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
      LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
      LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
      LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
      StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
      StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
      LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
      LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
      StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
      StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
      SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}
struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};
void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}