28#include "llvm/ADT/STLExtras.h"
29#include "llvm/ADT/SmallBitVector.h"
30#include "llvm/ADT/TypeSwitch.h"
31#include "llvm/Support/Debug.h"
34#define DEBUG_TYPE "fold-memref-alias-ops"
35#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
39#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
40#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
58 reassocs.take_back(n),
65 return llvm::all_of(subview.getStaticStrides().take_back(n),
66 [](
int64_t s) { return s == 1; });
70template <
typename LoadOrStoreOpTy>
72 return op.getMemref();
80 return op.getSrcMemref();
96 return op.getSrcMemref();
100 return op.getDstMemref();
109template <
typename OpTy>
112 using OpRewritePattern<OpTy>::OpRewritePattern;
114 LogicalResult matchAndRewrite(OpTy loadOp,
115 PatternRewriter &rewriter)
const override;
119template <
typename OpTy>
122 using OpRewritePattern<OpTy>::OpRewritePattern;
124 LogicalResult matchAndRewrite(OpTy loadOp,
125 PatternRewriter &rewriter)
const override;
129template <
typename OpTy>
132 using OpRewritePattern<OpTy>::OpRewritePattern;
134 LogicalResult matchAndRewrite(OpTy loadOp,
135 PatternRewriter &rewriter)
const override;
139template <
typename OpTy>
142 using OpRewritePattern<OpTy>::OpRewritePattern;
144 LogicalResult matchAndRewrite(OpTy storeOp,
145 PatternRewriter &rewriter)
const override;
149template <
typename OpTy>
152 using OpRewritePattern<OpTy>::OpRewritePattern;
154 LogicalResult matchAndRewrite(OpTy storeOp,
155 PatternRewriter &rewriter)
const override;
159template <
typename OpTy>
162 using OpRewritePattern<OpTy>::OpRewritePattern;
164 LogicalResult matchAndRewrite(OpTy storeOp,
165 PatternRewriter &rewriter)
const override;
171 using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
173 LogicalResult matchAndRewrite(memref::SubViewOp subView,
174 PatternRewriter &rewriter)
const override {
175 auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
180 if (!subView.hasUnitStride()) {
183 if (!srcSubView.hasUnitStride()) {
188 SmallVector<OpFoldResult> resolvedSizes;
189 llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
190 affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
191 subView.getMixedSizes(), srcDroppedDims,
195 SmallVector<Value> resolvedOffsets;
196 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
197 rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
198 srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
203 subView, subView.getType(), srcSubView.getSource(),
205 srcSubView.getMixedStrides());
213class NVGPUAsyncCopyOpSubViewOpFolder final
216 using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;
218 LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
219 PatternRewriter &rewriter)
const override;
225struct AccessOpOfSubViewOpFolder final
229 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
230 PatternRewriter &rewriter)
const override;
239struct AccessOpOfExpandShapeOpFolder final
243 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
244 PatternRewriter &rewriter)
const override;
255struct AccessOpOfCollapseShapeOpFolder final
259 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
260 PatternRewriter &rewriter)
const override;
268struct IndexedMemCopyOpOfSubViewOpFolder final
272 LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
273 PatternRewriter &rewriter)
const override;
279struct IndexedMemCopyOpOfExpandShapeOpFolder final
283 LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
284 PatternRewriter &rewriter)
const override;
290struct IndexedMemCopyOpOfCollapseShapeOpFolder final
294 LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
295 PatternRewriter &rewriter)
const override;
299template <
typename XferOp>
302 memref::SubViewOp subviewOp) {
304 !llvm::is_one_of<vector::TransferReadOp, vector::TransferWriteOp>::value,
305 "must be a vector transfer op");
306 if (xferOp.hasOutOfBoundsDim())
308 if (!subviewOp.hasUnitStride()) {
310 xferOp,
"non-1 stride subview, need to track strides in folded memref");
317 memref::SubViewOp subviewOp) {
322 vector::TransferReadOp readOp,
323 memref::SubViewOp subviewOp) {
328 vector::TransferWriteOp writeOp,
329 memref::SubViewOp subviewOp) {
333template <
typename OpTy>
334LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
342 LogicalResult preconditionResult =
344 if (
failed(preconditionResult))
345 return preconditionResult;
349 rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
350 subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
351 loadOp.getIndices(), sourceIndices);
354 .Case([&](memref::LoadOp op) {
356 loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
358 .Case([&](vector::LoadOp op) {
360 op, op.getType(), subViewOp.getSource(), sourceIndices);
362 .Case([&](vector::MaskedLoadOp op) {
364 op, op.getType(), subViewOp.getSource(), sourceIndices,
365 op.getMask(), op.getPassThru());
367 .Case([&](vector::TransferReadOp op) {
369 op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
371 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
372 subViewOp.getDroppedDims())),
373 op.getPadding(), op.getMask(), op.getInBoundsAttr());
375 .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
377 op, op.getType(), subViewOp.getSource(), sourceIndices,
378 op.getLeadDimension(), op.getTransposeAttr());
380 .Case([&](nvgpu::LdMatrixOp op) {
382 op, op.getType(), subViewOp.getSource(), sourceIndices,
383 op.getTranspose(), op.getNumTiles());
385 .DefaultUnreachable(
"unexpected operation");
389template <
typename OpTy>
390LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
403 if (
auto transferOp =
404 dyn_cast<vector::TransferReadOp>(loadOp.getOperation())) {
405 const int64_t vectorRank = transferOp.getVectorType().getRank();
407 cast<MemRefType>(expandShapeOp.getViewSource().getType()).getRank();
408 if (sourceRank < vectorRank)
414 bool foundExpr =
false;
415 for (
auto reassocationIndices :
416 llvm::enumerate(expandShapeOp.getReassociationIndices())) {
417 auto reassociation = reassocationIndices.value();
419 reassociation[reassociation.size() - 1], rewriter.
getContext());
422 reassocationIndices.index(), rewriter.
getContext()));
436 loadOp.getIndices(), sourceIndices,
437 isa<memref::LoadOp>(loadOp.getOperation()));
440 .Case([&](memref::LoadOp op) {
442 loadOp, expandShapeOp.getViewSource(), sourceIndices,
443 op.getNontemporal());
446 .Case([&](vector::LoadOp op) {
448 op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
449 op.getNontemporal());
452 .Case([&](vector::MaskedLoadOp op) {
454 op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
455 op.getMask(), op.getPassThru());
458 .Case([&](vector::TransferReadOp op) {
459 const int64_t sourceRank = sourceIndices.size();
460 auto newMap =
AffineMap::get(sourceRank, 0, transferReadNewResults,
463 op, op.getVectorType(), expandShapeOp.getViewSource(),
464 sourceIndices, newMap, op.getPadding(), op.getMask(),
468 .DefaultUnreachable(
"unexpected operation");
471template <
typename OpTy>
472LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
475 .template getDefiningOp<memref::CollapseShapeOp>();
477 if (!collapseShapeOp)
482 loadOp.getIndices(), sourceIndices);
484 .Case([&](memref::LoadOp op) {
486 loadOp, collapseShapeOp.getViewSource(), sourceIndices,
487 op.getNontemporal());
489 .Case([&](vector::LoadOp op) {
491 op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
492 op.getNontemporal());
494 .Case([&](vector::MaskedLoadOp op) {
496 op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
497 op.getMask(), op.getPassThru());
499 .DefaultUnreachable(
"unexpected operation");
503template <
typename OpTy>
504LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
512 LogicalResult preconditionResult =
514 if (
failed(preconditionResult))
515 return preconditionResult;
519 rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
520 subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
521 storeOp.getIndices(), sourceIndices);
524 .Case([&](memref::StoreOp op) {
526 op, op.getValue(), subViewOp.getSource(), sourceIndices,
527 op.getNontemporal());
529 .Case([&](vector::TransferWriteOp op) {
531 op, op.getValue(), subViewOp.getSource(), sourceIndices,
533 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
534 subViewOp.getDroppedDims())),
535 op.getMask(), op.getInBoundsAttr());
537 .Case([&](vector::StoreOp op) {
539 op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
541 .Case([&](vector::MaskedStoreOp op) {
543 op, subViewOp.getSource(), sourceIndices, op.getMask(),
544 op.getValueToStore());
546 .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
548 op, op.getSrc(), subViewOp.getSource(), sourceIndices,
549 op.getLeadDimension(), op.getTransposeAttr());
551 .DefaultUnreachable(
"unexpected operation");
555template <
typename OpTy>
556LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
568 storeOp.getIndices(), sourceIndices,
569 isa<memref::StoreOp>(storeOp.getOperation()));
571 .Case([&](memref::StoreOp op) {
573 storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
574 sourceIndices, op.getNontemporal());
576 .Case([&](vector::StoreOp op) {
578 op, op.getValueToStore(), expandShapeOp.getViewSource(),
579 sourceIndices, op.getNontemporal());
581 .Case([&](vector::MaskedStoreOp op) {
583 op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
584 op.getValueToStore());
586 .DefaultUnreachable(
"unexpected operation");
590template <
typename OpTy>
591LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
594 .template getDefiningOp<memref::CollapseShapeOp>();
596 if (!collapseShapeOp)
601 storeOp.getIndices(), sourceIndices);
603 .Case([&](memref::StoreOp op) {
605 storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
606 sourceIndices, op.getNontemporal());
608 .Case([&](vector::StoreOp op) {
610 op, op.getValueToStore(), collapseShapeOp.getViewSource(),
611 sourceIndices, op.getNontemporal());
613 .Case([&](vector::MaskedStoreOp op) {
615 op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
616 op.getValueToStore());
618 .DefaultUnreachable(
"unexpected operation");
623AccessOpOfSubViewOpFolder::matchAndRewrite(memref::IndexedAccessOpInterface op,
625 auto subview = op.getAccessedMemref().getDefiningOp<memref::SubViewOp>();
629 SmallVector<int64_t> accessedShape = op.getAccessedShape();
634 int64_t accessedDims = accessedShape.size();
637 op,
"non-unit stride on accessed dimensions");
639 llvm::SmallBitVector droppedDims = subview.getDroppedDims();
640 int64_t sourceRank = subview.getSourceType().getRank();
645 int64_t secondAccessedDim = sourceRank - (accessedDims - 1);
646 if (secondAccessedDim < sourceRank) {
647 for (int64_t d : llvm::seq(secondAccessedDim, sourceRank)) {
648 if (droppedDims.test(d))
650 op,
"reintroducing dropped dimension " + Twine(d) +
651 " would break access op semantics");
655 SmallVector<Value> sourceIndices;
656 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
657 rewriter, op.getLoc(), subview.getMixedOffsets(),
658 subview.getMixedStrides(), droppedDims, op.getIndices(), sourceIndices);
660 std::optional<SmallVector<Value>> newValues =
661 op.updateMemrefAndIndices(rewriter, subview.getSource(), sourceIndices);
667LogicalResult AccessOpOfExpandShapeOpFolder::matchAndRewrite(
668 memref::IndexedAccessOpInterface op, PatternRewriter &rewriter)
const {
669 auto expand = op.getAccessedMemref().getDefiningOp<memref::ExpandShapeOp>();
673 SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
674 ArrayRef<int64_t> accessedShape = rawAccessedShape;
677 if (!accessedShape.empty())
678 accessedShape = accessedShape.drop_front();
680 SmallVector<ReassociationIndices, 4> reassocs =
681 expand.getReassociationIndices();
685 "expand_shape folding would merge semanvtically important dimensions");
687 SmallVector<Value> sourceIndices;
689 op.getIndices(), sourceIndices,
690 op.hasInboundsIndices());
692 std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
693 rewriter, expand.getViewSource(), sourceIndices);
699LogicalResult AccessOpOfCollapseShapeOpFolder::matchAndRewrite(
700 memref::IndexedAccessOpInterface op, PatternRewriter &rewriter)
const {
702 op.getAccessedMemref().getDefiningOp<memref::CollapseShapeOp>();
706 SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
707 ArrayRef<int64_t> accessedShape = rawAccessedShape;
711 if (!accessedShape.empty())
712 accessedShape = accessedShape.drop_front();
714 SmallVector<ReassociationIndices, 4> reassocs =
715 collapse.getReassociationIndices();
718 "collapse_shape folding would merge "
719 "semanvtically important dimensions");
721 SmallVector<Value> sourceIndices;
723 op.getIndices(), sourceIndices);
725 std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
726 rewriter, collapse.getViewSource(), sourceIndices);
732LogicalResult IndexedMemCopyOpOfSubViewOpFolder::matchAndRewrite(
733 memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter)
const {
734 auto srcSubview = op.getSrc().getDefiningOp<memref::SubViewOp>();
735 auto dstSubview = op.getDst().getDefiningOp<memref::SubViewOp>();
736 if (!srcSubview && !dstSubview)
738 op,
"no subviews found on indexed copy inputs");
740 Value newSrc = op.getSrc();
741 SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
742 Value newDst = op.getDst();
743 SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
745 newSrc = srcSubview.getSource();
746 newSrcIndices.clear();
747 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
748 rewriter, op.getLoc(), srcSubview.getMixedOffsets(),
749 srcSubview.getMixedStrides(), srcSubview.getDroppedDims(),
750 op.getSrcIndices(), newSrcIndices);
753 newDst = dstSubview.getSource();
754 newDstIndices.clear();
755 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
756 rewriter, op.getLoc(), dstSubview.getMixedOffsets(),
757 dstSubview.getMixedStrides(), dstSubview.getDroppedDims(),
758 op.getDstIndices(), newDstIndices);
760 op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
765LogicalResult IndexedMemCopyOpOfExpandShapeOpFolder::matchAndRewrite(
766 memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter)
const {
767 auto srcExpand = op.getSrc().getDefiningOp<memref::ExpandShapeOp>();
768 auto dstExpand = op.getDst().getDefiningOp<memref::ExpandShapeOp>();
769 if (!srcExpand && !dstExpand)
771 op,
"no expand_shapes found on indexed copy inputs");
773 Value newSrc = op.getSrc();
774 SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
775 Value newDst = op.getDst();
776 SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
778 newSrc = srcExpand.getViewSource();
779 newSrcIndices.clear();
781 op.getSrcIndices(), newSrcIndices,
785 newDst = dstExpand.getViewSource();
786 newDstIndices.clear();
788 op.getDstIndices(), newDstIndices,
791 op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
796LogicalResult IndexedMemCopyOpOfCollapseShapeOpFolder::matchAndRewrite(
797 memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter)
const {
798 auto srcCollapse = op.getSrc().getDefiningOp<memref::CollapseShapeOp>();
799 auto dstCollapse = op.getDst().getDefiningOp<memref::CollapseShapeOp>();
800 if (!srcCollapse && !dstCollapse)
802 op,
"no collapse_shapes found on indexed copy inputs");
804 Value newSrc = op.getSrc();
805 SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
806 Value newDst = op.getDst();
807 SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
809 newSrc = srcCollapse.getViewSource();
810 newSrcIndices.clear();
812 op.getLoc(), rewriter, srcCollapse, op.getSrcIndices(), newSrcIndices);
815 newDst = dstCollapse.getViewSource();
816 newDstIndices.clear();
818 op.getLoc(), rewriter, dstCollapse, op.getDstIndices(), newDstIndices);
820 op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
825LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
826 nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter)
const {
828 LLVM_DEBUG(
DBGS() <<
"copyOp : " << copyOp <<
"\n");
831 copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
833 copyOp.getDst().template getDefiningOp<memref::SubViewOp>();
835 if (!(srcSubViewOp || dstSubViewOp))
837 "source or destination");
840 SmallVector<Value> foldedSrcIndices(copyOp.getSrcIndices().begin(),
841 copyOp.getSrcIndices().end());
844 LLVM_DEBUG(
DBGS() <<
"srcSubViewOp : " << srcSubViewOp <<
"\n");
845 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
846 rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
847 srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
848 copyOp.getSrcIndices(), foldedSrcIndices);
852 SmallVector<Value> foldedDstIndices(copyOp.getDstIndices().begin(),
853 copyOp.getDstIndices().end());
856 LLVM_DEBUG(
DBGS() <<
"dstSubViewOp : " << dstSubViewOp <<
"\n");
857 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
858 rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
859 dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
860 copyOp.getDstIndices(), foldedDstIndices);
866 copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
867 (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
869 (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
870 foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
871 copyOp.getBypassL1Attr());
879 AccessOpOfSubViewOpFolder, AccessOpOfExpandShapeOpFolder,
880 AccessOpOfCollapseShapeOpFolder, IndexedMemCopyOpOfSubViewOpFolder,
881 IndexedMemCopyOpOfExpandShapeOpFolder,
882 IndexedMemCopyOpOfCollapseShapeOpFolder,
884 LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
885 LoadOpOfSubViewOpFolder<vector::LoadOp>,
886 LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
887 LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
888 LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
889 StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
890 StoreOpOfSubViewOpFolder<vector::StoreOp>,
891 StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
892 StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
893 LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
894 LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
895 LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
896 StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
897 StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
898 LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
899 LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
900 StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
901 StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
902 SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
912struct FoldMemRefAliasOpsPass final
913 :
public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
914 void runOnOperation()
override;
919void FoldMemRefAliasOpsPass::runOnOperation() {
Base type for affine expression.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
MLIRContext * getContext() const
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into consumer load/store ops into patterns.
void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, memref::CollapseShapeOp collapseShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices)
Given the 'indices' of a load/store operation where the memref is a result of a collapse_shape op,...
void resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, memref::ExpandShapeOp expandShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of a expand_shape op,...
Include the generated interface declarations.
AffineMap expandDimsToRank(AffineMap map, int64_t rank, const llvm::SmallBitVector &projectedDimensions)
Expand map to operate on rank dims while projecting out the dims in projectedDimensions.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< int64_t, 2 > ReassociationIndices
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...