27#include "llvm/ADT/STLExtras.h"
28#include "llvm/ADT/SmallBitVector.h"
29#include "llvm/ADT/TypeSwitch.h"
30#include "llvm/Support/Debug.h"
33#define DEBUG_TYPE "fold-memref-alias-ops"
34#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
38#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
39#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
57 reassocs.take_back(n),
64 return llvm::all_of(subview.getStaticStrides().take_back(n),
65 [](
int64_t s) { return s == 1; });
69template <
typename LoadOrStoreOpTy>
71 return op.getMemref();
79 return op.getSrcMemref();
100template <
typename OpTy>
103 using OpRewritePattern<OpTy>::OpRewritePattern;
105 LogicalResult matchAndRewrite(OpTy loadOp,
106 PatternRewriter &rewriter)
const override;
110template <
typename OpTy>
113 using OpRewritePattern<OpTy>::OpRewritePattern;
115 LogicalResult matchAndRewrite(OpTy loadOp,
116 PatternRewriter &rewriter)
const override;
120template <
typename OpTy>
123 using OpRewritePattern<OpTy>::OpRewritePattern;
125 LogicalResult matchAndRewrite(OpTy loadOp,
126 PatternRewriter &rewriter)
const override;
130template <
typename OpTy>
133 using OpRewritePattern<OpTy>::OpRewritePattern;
135 LogicalResult matchAndRewrite(OpTy storeOp,
136 PatternRewriter &rewriter)
const override;
140template <
typename OpTy>
143 using OpRewritePattern<OpTy>::OpRewritePattern;
145 LogicalResult matchAndRewrite(OpTy storeOp,
146 PatternRewriter &rewriter)
const override;
150template <
typename OpTy>
153 using OpRewritePattern<OpTy>::OpRewritePattern;
155 LogicalResult matchAndRewrite(OpTy storeOp,
156 PatternRewriter &rewriter)
const override;
162 using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
// Rewrites a memref.subview whose source is itself produced by a
// memref.subview so that the outer op reads from the inner op's source.
// NOTE(review): several lines of this body (the null check on `srcSubView`,
// the callee names, and the returns) are not visible in this listing.
164 LogicalResult matchAndRewrite(memref::SubViewOp subView,
165 PatternRewriter &rewriter)
const override {
// Null when the source is not defined by another subview.
166 auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
// Out-parameters that receive the combined offsets/sizes/strides.
170 SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
// Arguments of the merge helper; the dims dropped by the producer subview
// are passed along so they can be accounted for.
172 rewriter, subView.getLoc(), srcSubView, subView,
173 srcSubView.getDroppedDims(), newOffsets, newSizes, newStrides)))
// Replacement op: keeps the original result type, takes the producer's
// source together with the combined offsets/sizes/strides.
178 subView, subView.getType(), srcSubView.getSource(), newOffsets,
179 newSizes, newStrides);
// Rewrite pattern over nvgpu::DeviceAsyncCopyOp. Only the declaration of
// matchAndRewrite appears here; the out-of-line definition later in this
// file looks for memref.subview producers of the copy's source/destination.
186class NVGPUAsyncCopyOpSubViewOpFolder final
// Inherit the base-class constructors.
189 using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;
191 LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
192 PatternRewriter &rewriter)
const override;
// Pattern matching any op that implements memref::IndexedAccessOpInterface.
// The out-of-line matchAndRewrite looks for a memref.subview producing the
// accessed memref. (Base-class line not visible in this listing.)
198struct AccessOpOfSubViewOpFolder final
202 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
203 PatternRewriter &rewriter)
const override;
// Pattern matching any op that implements memref::IndexedAccessOpInterface.
// The out-of-line matchAndRewrite looks for a memref.expand_shape producing
// the accessed memref. (Base-class line not visible in this listing.)
212struct AccessOpOfExpandShapeOpFolder final
216 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
217 PatternRewriter &rewriter)
const override;
// Pattern matching any op that implements memref::IndexedAccessOpInterface.
// The out-of-line matchAndRewrite looks for a memref.collapse_shape producing
// the accessed memref. (Base-class line not visible in this listing.)
228struct AccessOpOfCollapseShapeOpFolder final
232 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
233 PatternRewriter &rewriter)
const override;
// Pattern matching any op that implements memref::IndexedMemCopyOpInterface.
// The out-of-line matchAndRewrite inspects both the copy source and the copy
// destination for memref.subview producers.
241struct IndexedMemCopyOpOfSubViewOpFolder final
245 LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
246 PatternRewriter &rewriter)
const override;
// Pattern matching any op that implements memref::IndexedMemCopyOpInterface.
// The out-of-line matchAndRewrite inspects both the copy source and the copy
// destination for memref.expand_shape producers.
252struct IndexedMemCopyOpOfExpandShapeOpFolder final
256 LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
257 PatternRewriter &rewriter)
const override;
// Pattern matching any op that implements memref::IndexedMemCopyOpInterface.
// The out-of-line matchAndRewrite inspects both the copy source and the copy
// destination for memref.collapse_shape producers.
263struct IndexedMemCopyOpOfCollapseShapeOpFolder final
267 LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
268 PatternRewriter &rewriter)
const override;
272template <
typename XferOp>
275 memref::SubViewOp subviewOp) {
277 !llvm::is_one_of<vector::TransferReadOp, vector::TransferWriteOp>::value,
278 "must be a vector transfer op");
279 if (xferOp.hasOutOfBoundsDim())
281 if (!subviewOp.hasUnitStride()) {
283 xferOp,
"non-1 stride subview, need to track strides in folded memref");
290 memref::SubViewOp subviewOp) {
295 vector::TransferReadOp readOp,
296 memref::SubViewOp subviewOp) {
301 vector::TransferWriteOp writeOp,
302 memref::SubViewOp subviewOp) {
// Folds a memref.subview into the load-like op that reads from it: the load
// is re-created on the subview's source with indices translated through the
// subview's offsets and strides.
// NOTE(review): the parameter list, the lookup of `subViewOp`, and the callee
// names are on lines not visible in this listing.
306template <
typename OpTy>
307LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
// A failed precondition is propagated unchanged to the caller.
315 LogicalResult preconditionResult =
317 if (
failed(preconditionResult))
318 return preconditionResult;
// Index translation: the subview's mixed offsets/strides and dropped dims are
// combined with the load indices; the result is written to `sourceIndices`.
// Presumably a call to resolveIndicesIntoOpWithOffsetsAndStrides - confirm.
322 rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
323 subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
324 loadOp.getIndices(), sourceIndices);
// Per-op-kind dispatch: every case rebuilds the op on subViewOp.getSource()
// with `sourceIndices`, forwarding the op-specific operands/attributes.
327 .Case([&](memref::LoadOp op) {
329 loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
// NOTE(review): unlike the expand_shape/collapse_shape load folders in this
// file, this case does not forward op.getNontemporal() - confirm whether
// dropping it is intentional.
331 .Case([&](vector::LoadOp op) {
333 op, op.getType(), subViewOp.getSource(), sourceIndices);
335 .Case([&](vector::MaskedLoadOp op) {
337 op, op.getType(), subViewOp.getSource(), sourceIndices,
338 op.getMask(), op.getPassThru());
// The permutation map is widened to the source rank, taking the dropped dims
// into account, before it is attached to the new transfer_read.
340 .Case([&](vector::TransferReadOp op) {
342 op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
344 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
345 subViewOp.getDroppedDims())),
346 op.getPadding(), op.getMask(), op.getInBoundsAttr());
348 .Case([&](nvgpu::LdMatrixOp op) {
350 op, op.getType(), subViewOp.getSource(), sourceIndices,
351 op.getTranspose(), op.getNumTiles());
// Any other op type is treated as a programming error.
353 .DefaultUnreachable(
"unexpected operation");
// Folds a memref.expand_shape into the load-like op that reads from it: the
// load is re-created on the expand_shape's view source with translated
// indices.
// NOTE(review): the parameter list, the lookup of `expandShapeOp`, and several
// statements are on lines not visible in this listing.
357template <
typename OpTy>
358LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
// Extra handling that applies only when the op is a vector.transfer_read.
371 if (
auto transferOp =
372 dyn_cast<vector::TransferReadOp>(loadOp.getOperation())) {
373 const int64_t vectorRank = transferOp.getVectorType().getRank();
// Rank of the memref feeding the expand_shape.
375 cast<MemRefType>(expandShapeOp.getViewSource().getType()).getRank();
// Source rank smaller than vector rank is special-cased; the action taken is
// on a line not visible here (presumably a match failure - confirm).
376 if (sourceRank < vectorRank)
382 bool foundExpr =
false;
// Walk the reassociation groups together with their positions.
383 for (
auto reassocationIndices :
384 llvm::enumerate(expandShapeOp.getReassociationIndices())) {
385 auto reassociation = reassocationIndices.value();
// Uses the last index of the current reassociation group.
387 reassociation[reassociation.size() - 1], rewriter.
getContext());
// Uses the position of the group itself.
390 reassocationIndices.index(), rewriter.
getContext()));
// Index translation into `sourceIndices`; the trailing flag is true only for
// memref.load. Presumably the `startsInbounds` argument of
// resolveSourceIndicesExpandShape - confirm.
404 loadOp.getIndices(), sourceIndices,
405 isa<memref::LoadOp>(loadOp.getOperation()));
// Per-op-kind dispatch: every case rebuilds the op on
// expandShapeOp.getViewSource() with `sourceIndices`.
408 .Case([&](memref::LoadOp op) {
410 loadOp, expandShapeOp.getViewSource(), sourceIndices,
411 op.getNontemporal());
414 .Case([&](vector::LoadOp op) {
416 op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
417 op.getNontemporal());
420 .Case([&](vector::MaskedLoadOp op) {
422 op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
423 op.getMask(), op.getPassThru());
// transfer_read gets a new permutation map with `sourceRank` dims and zero
// symbols, built from `transferReadNewResults`.
426 .Case([&](vector::TransferReadOp op) {
427 const int64_t sourceRank = sourceIndices.size();
428 auto newMap =
AffineMap::get(sourceRank, 0, transferReadNewResults,
431 op, op.getVectorType(), expandShapeOp.getViewSource(),
432 sourceIndices, newMap, op.getPadding(), op.getMask(),
// Any other op type is treated as a programming error.
436 .DefaultUnreachable(
"unexpected operation");
// Folds a memref.collapse_shape into the load-like op that reads from it: the
// load is re-created on the collapse_shape's view source with translated
// indices.
// NOTE(review): the parameter list and the callee names are on lines not
// visible in this listing.
439template <
typename OpTy>
440LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
443 .template getDefiningOp<memref::CollapseShapeOp>();
// No collapse_shape producer: the statement taken here is not visible
// (presumably `return failure();` - confirm).
445 if (!collapseShapeOp)
// Tail of the index-translation call; results land in `sourceIndices`.
450 loadOp.getIndices(), sourceIndices);
// Per-op-kind dispatch: every case rebuilds the op on
// collapseShapeOp.getViewSource() with `sourceIndices`.
452 .Case([&](memref::LoadOp op) {
454 loadOp, collapseShapeOp.getViewSource(), sourceIndices,
455 op.getNontemporal());
457 .Case([&](vector::LoadOp op) {
459 op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
460 op.getNontemporal());
462 .Case([&](vector::MaskedLoadOp op) {
464 op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
465 op.getMask(), op.getPassThru());
// Any other op type is treated as a programming error.
467 .DefaultUnreachable(
"unexpected operation");
// Folds a memref.subview into the store-like op that writes to it: the store
// is re-created on the subview's source with indices translated through the
// subview's offsets and strides.
// NOTE(review): the parameter list, the lookup of `subViewOp`, and the callee
// names are on lines not visible in this listing.
471template <
typename OpTy>
472LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
// A failed precondition is propagated unchanged to the caller.
480 LogicalResult preconditionResult =
482 if (failed(preconditionResult))
483 return preconditionResult;
// Index translation: the subview's mixed offsets/strides and dropped dims are
// combined with the store indices; the result is written to `sourceIndices`.
487 rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
488 subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
489 storeOp.getIndices(), sourceIndices);
// Per-op-kind dispatch: every case rebuilds the op on subViewOp.getSource()
// with `sourceIndices`.
492 .Case([&](memref::StoreOp op) {
494 op, op.getValue(), subViewOp.getSource(), sourceIndices,
495 op.getNontemporal());
// The permutation map is widened to the source rank, taking the dropped dims
// into account, before it is attached to the new transfer_write.
497 .Case([&](vector::TransferWriteOp op) {
499 op, op.getValue(), subViewOp.getSource(), sourceIndices,
501 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
502 subViewOp.getDroppedDims())),
503 op.getMask(), op.getInBoundsAttr());
// NOTE(review): unlike the expand_shape/collapse_shape store folders in this
// file, this case does not forward op.getNontemporal() - confirm whether
// dropping it is intentional.
505 .Case([&](vector::StoreOp op) {
507 op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
509 .Case([&](vector::MaskedStoreOp op) {
511 op, subViewOp.getSource(), sourceIndices, op.getMask(),
512 op.getValueToStore());
// Any other op type is treated as a programming error.
514 .DefaultUnreachable(
"unexpected operation");
// Folds a memref.expand_shape into the store-like op that writes to it: the
// store is re-created on the expand_shape's view source with translated
// indices.
// NOTE(review): the parameter list, the lookup of `expandShapeOp`, and the
// callee names are on lines not visible in this listing.
518template <
typename OpTy>
519LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
// Tail of the index-translation call; the trailing flag is true only for
// memref.store. Presumably the `startsInbounds` argument of
// resolveSourceIndicesExpandShape - confirm.
531 storeOp.getIndices(), sourceIndices,
532 isa<memref::StoreOp>(storeOp.getOperation()));
// Per-op-kind dispatch: every case rebuilds the op on
// expandShapeOp.getViewSource() with `sourceIndices`.
534 .Case([&](memref::StoreOp op) {
536 storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
537 sourceIndices, op.getNontemporal());
539 .Case([&](vector::StoreOp op) {
541 op, op.getValueToStore(), expandShapeOp.getViewSource(),
542 sourceIndices, op.getNontemporal());
544 .Case([&](vector::MaskedStoreOp op) {
546 op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
547 op.getValueToStore());
// Any other op type is treated as a programming error.
549 .DefaultUnreachable(
"unexpected operation");
// Folds a memref.collapse_shape into the store-like op that writes to it: the
// store is re-created on the collapse_shape's view source with translated
// indices.
// NOTE(review): the parameter list and the callee names are on lines not
// visible in this listing.
553template <
typename OpTy>
554LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
557 .template getDefiningOp<memref::CollapseShapeOp>();
// No collapse_shape producer: the statement taken here is not visible
// (presumably `return failure();` - confirm).
559 if (!collapseShapeOp)
// Tail of the index-translation call; results land in `sourceIndices`.
564 storeOp.getIndices(), sourceIndices);
// Per-op-kind dispatch: every case rebuilds the op on
// collapseShapeOp.getViewSource() with `sourceIndices`.
566 .Case([&](memref::StoreOp op) {
568 storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
569 sourceIndices, op.getNontemporal());
571 .Case([&](vector::StoreOp op) {
573 op, op.getValueToStore(), collapseShapeOp.getViewSource(),
574 sourceIndices, op.getNontemporal());
576 .Case([&](vector::MaskedStoreOp op) {
578 op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
579 op.getValueToStore());
// Any other op type is treated as a programming error.
581 .DefaultUnreachable(
"unexpected operation");
// Folds a memref.subview into an op implementing IndexedAccessOpInterface by
// re-pointing the op at the subview's source with translated indices.
// NOTE(review): the return type, the early exits, and the callee names are on
// lines not visible in this listing.
586AccessOpOfSubViewOpFolder::matchAndRewrite(memref::IndexedAccessOpInterface op,
// Null when the accessed memref is not produced by a subview.
588 auto subview = op.getAccessedMemref().getDefiningOp<memref::SubViewOp>();
592 SmallVector<int64_t> accessedShape = op.getAccessedShape();
// Number of entries in the accessed shape reported by the op.
597 int64_t accessedDims = accessedShape.size();
// Match-failure reason used by a stride check whose condition is not visible.
600 op,
"non-unit stride on accessed dimensions");
602 llvm::SmallBitVector droppedDims = subview.getDroppedDims();
603 int64_t sourceRank = subview.getSourceType().getRank();
// Inspect the last (accessedDims - 1) source dimensions: if the subview
// dropped any of them, the fold is rejected with the message below.
608 int64_t secondAccessedDim = sourceRank - (accessedDims - 1);
// The range is only iterated when it is non-empty.
609 if (secondAccessedDim < sourceRank) {
610 for (int64_t d : llvm::seq(secondAccessedDim, sourceRank)) {
611 if (droppedDims.test(d))
613 op,
"reintroducing dropped dimension " + Twine(d) +
614 " would break access op semantics");
// Receives the indices expressed in terms of the subview's source.
618 SmallVector<Value> sourceIndices;
620 rewriter, op.getLoc(), subview.getMixedOffsets(),
621 subview.getMixedStrides(), droppedDims, op.getIndices(), sourceIndices);
// Ask the op to switch to the subview's source and the new indices; how the
// optional result is consumed is not visible here.
623 std::optional<SmallVector<Value>> newValues =
624 op.updateMemrefAndIndices(rewriter, subview.getSource(), sourceIndices);
// Folds a memref.expand_shape into an op implementing
// IndexedAccessOpInterface by re-pointing the op at the expand_shape's view
// source with translated indices.
// NOTE(review): the early exits and the callee names are on lines not visible
// in this listing.
630LogicalResult AccessOpOfExpandShapeOpFolder::matchAndRewrite(
631 memref::IndexedAccessOpInterface op, PatternRewriter &rewriter)
const {
// Null when the accessed memref is not produced by an expand_shape.
632 auto expand = op.getAccessedMemref().getDefiningOp<memref::ExpandShapeOp>();
// Keep the owning vector alive; `accessedShape` is only a view into it.
636 SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
637 ArrayRef<int64_t> accessedShape = rawAccessedShape;
// Ignore the leading accessed dimension (when there is one) for the check
// that follows.
640 if (!accessedShape.empty())
641 accessedShape = accessedShape.drop_front();
643 SmallVector<ReassociationIndices, 4> reassocs =
644 expand.getReassociationIndices();
// Match-failure reason (spelling of "semantically" corrected).
648 "expand_shape folding would merge semantically important dimensions");
// Receives the indices expressed in terms of the expand_shape's source.
650 SmallVector<Value> sourceIndices;
// Tail of the index-translation call; the op reports whether its indices are
// known to be in bounds.
652 op.getIndices(), sourceIndices,
653 op.hasInboundsIndices());
// Ask the op to switch to the expand_shape's source and the new indices; how
// the optional result is consumed is not visible here.
655 std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
656 rewriter, expand.getViewSource(), sourceIndices);
// Folds a memref.collapse_shape into an op implementing
// IndexedAccessOpInterface by re-pointing the op at the collapse_shape's view
// source with translated indices.
// NOTE(review): the early exits and the callee names are on lines not visible
// in this listing.
662LogicalResult AccessOpOfCollapseShapeOpFolder::matchAndRewrite(
663 memref::IndexedAccessOpInterface op, PatternRewriter &rewriter)
const {
// Null when the accessed memref is not produced by a collapse_shape.
665 op.getAccessedMemref().getDefiningOp<memref::CollapseShapeOp>();
// Keep the owning vector alive; `accessedShape` is only a view into it.
669 SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
670 ArrayRef<int64_t> accessedShape = rawAccessedShape;
// Ignore the leading accessed dimension (when there is one) for the check
// that follows.
674 if (!accessedShape.empty())
675 accessedShape = accessedShape.drop_front();
677 SmallVector<ReassociationIndices, 4> reassocs =
678 collapse.getReassociationIndices();
// Match-failure reason (spelling of "semantically" corrected).
681 "collapse_shape folding would merge "
682 "semantically important dimensions");
// Receives the indices expressed in terms of the collapse_shape's source.
684 SmallVector<Value> sourceIndices;
686 op.getIndices(), sourceIndices);
// Ask the op to switch to the collapse_shape's source and the new indices;
// how the optional result is consumed is not visible here.
688 std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
689 rewriter, collapse.getViewSource(), sourceIndices);
// Folds memref.subview producers on the source and/or destination of an
// indexed copy op into the copy itself.
// NOTE(review): the guards around each side's rewrite and the callee names are
// on lines not visible in this listing.
695LogicalResult IndexedMemCopyOpOfSubViewOpFolder::matchAndRewrite(
696 memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter)
const {
// Either handle may be null; the pattern needs at least one of them.
697 auto srcSubview = op.getSrc().getDefiningOp<memref::SubViewOp>();
698 auto dstSubview = op.getDst().getDefiningOp<memref::SubViewOp>();
699 if (!srcSubview && !dstSubview)
701 op,
"no subviews found on indexed copy inputs");
// Start from the current operands so that a side without a subview is passed
// through unchanged.
703 Value newSrc = op.getSrc();
704 SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
705 Value newDst = op.getDst();
706 SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
// Source side (presumably guarded by `if (srcSubview)` - confirm): switch to
// the subview's source and recompute the indices from scratch.
708 newSrc = srcSubview.getSource();
709 newSrcIndices.clear();
711 rewriter, op.getLoc(), srcSubview.getMixedOffsets(),
712 srcSubview.getMixedStrides(), srcSubview.getDroppedDims(),
713 op.getSrcIndices(), newSrcIndices);
// Destination side (presumably guarded by `if (dstSubview)` - confirm).
716 newDst = dstSubview.getSource();
717 newDstIndices.clear();
719 rewriter, op.getLoc(), dstSubview.getMixedOffsets(),
720 dstSubview.getMixedStrides(), dstSubview.getDroppedDims(),
721 op.getDstIndices(), newDstIndices);
// Install the new memrefs and indices on the op (remaining arguments not
// visible).
723 op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
// Folds memref.expand_shape producers on the source and/or destination of an
// indexed copy op into the copy itself.
// NOTE(review): the guards around each side's rewrite and the callee names are
// on lines not visible in this listing.
728LogicalResult IndexedMemCopyOpOfExpandShapeOpFolder::matchAndRewrite(
729 memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter)
const {
// Either handle may be null; the pattern needs at least one of them.
730 auto srcExpand = op.getSrc().getDefiningOp<memref::ExpandShapeOp>();
731 auto dstExpand = op.getDst().getDefiningOp<memref::ExpandShapeOp>();
732 if (!srcExpand && !dstExpand)
734 op,
"no expand_shapes found on indexed copy inputs");
// Start from the current operands so that a side without an expand_shape is
// passed through unchanged.
736 Value newSrc = op.getSrc();
737 SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
738 Value newDst = op.getDst();
739 SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
// Source side (presumably guarded by `if (srcExpand)` - confirm): switch to
// the expand_shape's view source and recompute the indices from scratch.
741 newSrc = srcExpand.getViewSource();
742 newSrcIndices.clear();
744 op.getSrcIndices(), newSrcIndices,
// Destination side (presumably guarded by `if (dstExpand)` - confirm).
748 newDst = dstExpand.getViewSource();
749 newDstIndices.clear();
751 op.getDstIndices(), newDstIndices,
// Install the new memrefs and indices on the op (remaining arguments not
// visible).
754 op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
// Folds memref.collapse_shape producers on the source and/or destination of
// an indexed copy op into the copy itself.
// NOTE(review): the guards around each side's rewrite and the callee names are
// on lines not visible in this listing.
759LogicalResult IndexedMemCopyOpOfCollapseShapeOpFolder::matchAndRewrite(
760 memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter)
const {
// Either handle may be null; the pattern needs at least one of them.
761 auto srcCollapse = op.getSrc().getDefiningOp<memref::CollapseShapeOp>();
762 auto dstCollapse = op.getDst().getDefiningOp<memref::CollapseShapeOp>();
763 if (!srcCollapse && !dstCollapse)
765 op,
"no collapse_shapes found on indexed copy inputs");
// Start from the current operands so that a side without a collapse_shape is
// passed through unchanged.
767 Value newSrc = op.getSrc();
768 SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
769 Value newDst = op.getDst();
770 SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
// Source side (presumably guarded by `if (srcCollapse)` - confirm): switch to
// the collapse_shape's view source and recompute the indices from scratch.
772 newSrc = srcCollapse.getViewSource();
773 newSrcIndices.clear();
775 op.getLoc(), rewriter, srcCollapse, op.getSrcIndices(), newSrcIndices);
// Destination side (presumably guarded by `if (dstCollapse)` - confirm).
778 newDst = dstCollapse.getViewSource();
779 newDstIndices.clear();
781 op.getLoc(), rewriter, dstCollapse, op.getDstIndices(), newDstIndices);
// Install the new memrefs and indices on the op (remaining arguments not
// visible).
783 op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
// Folds memref.subview producers on the source and/or destination of an
// nvgpu::DeviceAsyncCopyOp into the copy, replacing it with a copy on the
// underlying memrefs.
// NOTE(review): the guards around each side's index rewrite and the callee
// names are on lines not visible in this listing.
788LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
789 nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter)
const {
791 LLVM_DEBUG(
DBGS() <<
"copyOp : " << copyOp <<
"\n");
// Subview producers of the source and destination; either may be null.
794 copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
796 copyOp.getDst().template getDefiningOp<memref::SubViewOp>();
// Nothing to fold when neither side comes from a subview.
798 if (!(srcSubViewOp || dstSubViewOp))
800 "source or destination");
// Seeded with the current source indices so that they are reused as-is when
// the source has no subview.
803 SmallVector<Value> foldedSrcIndices(copyOp.getSrcIndices().begin(),
804 copyOp.getSrcIndices().end());
807 LLVM_DEBUG(
DBGS() <<
"srcSubViewOp : " << srcSubViewOp <<
"\n");
// Translate the source indices through the source subview.
809 rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
810 srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
811 copyOp.getSrcIndices(), foldedSrcIndices);
// Same seeding for the destination indices.
815 SmallVector<Value> foldedDstIndices(copyOp.getDstIndices().begin(),
816 copyOp.getDstIndices().end());
819 LLVM_DEBUG(
DBGS() <<
"dstSubViewOp : " << dstSubViewOp <<
"\n");
// Translate the destination indices through the destination subview.
821 rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
822 dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
823 copyOp.getDstIndices(), foldedDstIndices);
// Replacement copy: each side uses the subview's source when a subview was
// found and the original operand otherwise; element counts and the bypassL1
// attribute are carried over from the original op.
829 copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
830 (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
832 (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
833 foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
834 copyOp.getBypassL1Attr());
842 AccessOpOfSubViewOpFolder, AccessOpOfExpandShapeOpFolder,
843 AccessOpOfCollapseShapeOpFolder, IndexedMemCopyOpOfSubViewOpFolder,
844 IndexedMemCopyOpOfExpandShapeOpFolder,
845 IndexedMemCopyOpOfCollapseShapeOpFolder,
847 LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
848 LoadOpOfSubViewOpFolder<vector::LoadOp>,
849 LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
850 LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
851 StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
852 StoreOpOfSubViewOpFolder<vector::StoreOp>,
853 StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
854 LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
855 LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
856 LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
857 StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
858 StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
859 LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
860 LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
861 StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
862 StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
863 SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
// Pass type for this file. Only the runOnOperation override is declared here;
// its definition follows out of line. (Base-class line not visible in this
// listing.)
873struct FoldMemRefAliasOpsPass final
875 void runOnOperation()
override;
880void FoldMemRefAliasOpsPass::runOnOperation() {
Base type for affine expression.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
MLIRContext * getContext() const
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...
LogicalResult mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > producerOffsets, ArrayRef< OpFoldResult > producerSizes, ArrayRef< OpFoldResult > producerStrides, const llvm::SmallBitVector &droppedProducerDims, ArrayRef< OpFoldResult > consumerOffsets, ArrayRef< OpFoldResult > consumerSizes, ArrayRef< OpFoldResult > consumerStrides, SmallVector< OpFoldResult > &combinedOffsets, SmallVector< OpFoldResult > &combinedSizes, SmallVector< OpFoldResult > &combinedStrides)
Fills the combinedOffsets, combinedSizes and combinedStrides to use when combining a producer slice i...
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into consumer load/store ops into patterns.
void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, memref::CollapseShapeOp collapseShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices)
Given the 'indices' of a load/store operation where the memref is a result of a collapse_shape op,...
void resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, memref::ExpandShapeOp expandShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of a expand_shape op,...
Include the generated interface declarations.
AffineMap expandDimsToRank(AffineMap map, int64_t rank, const llvm::SmallBitVector &projectedDimensions)
Expand map to operate on rank dims while projecting out the dims in projectedDimensions.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
SmallVector< int64_t, 2 > ReassociationIndices
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...