25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallBitVector.h"
31#define GEN_PASS_DEF_EXPANDSTRIDEDMETADATAPASS
32#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
/// Strided layout of a memref: base buffer, offset, per-dimension sizes and
/// per-dimension strides. Components are OpFoldResults, so each one can be
/// either an attribute (e.g. a constant index) or an SSA Value. Instances are
/// aggregate-initialized in the order {basePtr, offset, sizes, strides}.
41struct StridedMetadata {
// One entry per dimension of the described memref.
44 SmallVector<OpFoldResult> sizes;
// One stride per dimension, parallel to `sizes`.
45 SmallVector<OpFoldResult> strides;
/// Expresses the strided metadata (base buffer, offset, sizes, strides) of the
/// result of `subview` in terms of a newly created
/// memref.extract_strided_metadata on the subview's source.
/// Returns a FailureOr; callers check `failed(...)` and report a match failure.
59static FailureOr<StridedMetadata>
61 memref::SubViewOp subview) {
64 Value source = subview.getSource();
65 auto sourceType = cast<MemRefType>(source.
getType());
66 unsigned sourceRank = sourceType.getRank();
68 auto newExtractStridedMetadata =
69 memref::ExtractStridedMetadataOp::create(rewriter, origLoc, source);
// Static layout information of the source and result types. Entries that
// ShapedType::isDynamic reports as dynamic are presumably replaced by the
// values produced by the new extract op.
71 auto [sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
73 auto [resultStrides, resultOffset] = subview.getType().getStridesAndOffset();
81 auto origStrides = newExtractStridedMetadata.getStrides();
// Symbol 0 of the offset expression is the source offset. It is an index
// constant when statically known.
89 values[0] = ShapedType::isDynamic(sourceOffset)
91 : rewriter.getIndexAttr(sourceOffset);
// For each source dimension i:
// - the resulting stride is subStrides[i] * (source stride i);
// - the offset gains subOffsets[i] * (source stride i).
96 for (
unsigned i = 0; i < sourceRank; ++i) {
99 ShapedType::isDynamic(sourceStrides[i])
103 rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));
// Symbol layout of the offset expression:
// - symbol 1 + 2*i is the subview offset of dim i;
// - symbol 2 + 2*i is the source stride of dim i.
106 unsigned baseIdxForDim = 1 + 2 * i;
107 unsigned subOffsetForDim = baseIdxForDim;
108 unsigned origStrideForDim = baseIdxForDim + 1;
109 expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
110 values[subOffsetForDim] = subOffsets[i];
111 values[origStrideForDim] = origStride;
// Assert-only sanity check: if the computed offset folded to a constant and
// the result type's offset is static, the two must agree.
121 if (computedOffset && ShapedType::isStatic(resultOffset))
122 assert(*computedOffset == resultOffset &&
123 "mismatch between computed offset and result type offset");
129 auto subType = cast<MemRefType>(subview.getType());
130 unsigned subRank = subType.getRank();
// Dimensions reported by getDroppedDims() (rank-reducing subview) get no
// entry in the final sizes/strides.
139 llvm::SmallBitVector droppedDims = subview.getDroppedDims();
142 finalSizes.reserve(subRank);
145 finalStrides.reserve(subRank);
151 for (
unsigned i = 0; i < sourceRank; ++i) {
152 if (droppedDims.test(i))
155 finalSizes.push_back(subSizes[i]);
156 finalStrides.push_back(strides[i]);
// Assert-only check of each computed stride against the result type's static
// stride, when both are known. `j` presumably indexes the result dimension.
161 if (computedStride && ShapedType::isStatic(resultStrides[
j]))
162 assert(*computedStride == resultStrides[
j] &&
163 "mismatch between computed stride and result type stride");
// Every kept (non-dropped) dimension must have contributed exactly one size.
167 assert(finalSizes.size() == subRank &&
168 "Should have populated all the values at this point");
169 return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset,
170 finalSizes, finalStrides};
// Pattern rooted on memref.subview. It resolves the subview's strided metadata
// and replaces the subview with an op built from (result type, basePtr,
// offset, sizes, strides). That op is presumably memref.reinterpret_cast;
// the created op's type is on a line not visible here.
189 using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
191 LogicalResult matchAndRewrite(memref::SubViewOp subview,
192 PatternRewriter &rewriter)
const override {
193 FailureOr<StridedMetadata> stridedMetadata =
194 resolveSubviewStridedMetadata(rewriter, subview);
195 if (
failed(stridedMetadata)) {
197 "failed to resolve subview metadata");
201 subview, subview.getType(), stridedMetadata->basePtr,
202 stridedMetadata->offset, stridedMetadata->sizes,
203 stridedMetadata->strides);
/// Folds memref.extract_strided_metadata whose source is produced by a
/// memref.subview. The results are re-expressed via
/// resolveSubviewStridedMetadata, i.e. in terms of the subview's own source.
/// Replacement values are laid out as
/// [basePtr, offset, sizes..., strides...] (2 * rank + 2 entries).
221struct ExtractStridedMetadataOpSubviewFolder
227 auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>();
231 FailureOr<StridedMetadata> stridedMetadata =
232 resolveSubviewStridedMetadata(rewriter, subviewOp);
233 if (failed(stridedMetadata)) {
235 op,
"failed to resolve metadata in terms of source subview op");
239 results.reserve(subviewOp.getType().getRank() * 2 + 2);
240 results.push_back(stridedMetadata->basePtr);
242 stridedMetadata->offset));
246 stridedMetadata->strides));
/// Computes the sizes of the result dimensions that reassociation group
/// `groupId` of `expandShape` expands into.
/// - Static sizes come from the result type.
/// - At most one size per group may be dynamic (asserted below). It is
///   derived by floor-dividing an input value (presumably the original size
///   of source dimension `groupId`) by the product of the group's static
///   sizes.
269getExpandedSizes(memref::ExpandShapeOp expandShape,
OpBuilder &builder,
272 expandShape.getReassociationIndices()[groupId];
273 assert(!reassocGroup.empty() &&
274 "Reassociation group should have at least one dimension");
276 unsigned groupSize = reassocGroup.size();
// Accumulate the product of the static sizes and remember the position of the
// (single) dynamic size, if any.
279 uint64_t productOfAllStaticSizes = 1;
280 std::optional<unsigned> dynSizeIdx;
281 MemRefType expandShapeType = expandShape.getResultType();
284 for (
unsigned i = 0; i < groupSize; ++i) {
// NOTE(review): dimSize is held as uint64_t, while the dynamic-size sentinel
// tested by ShapedType::isDynamic is presumably a signed int64_t value. The
// check relies on the signed/unsigned round-trip; int64_t may be clearer.
285 uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
286 if (ShapedType::isDynamic(dimSize)) {
287 assert(!dynSizeIdx &&
"There must be at most one dynamic size per group");
291 productOfAllStaticSizes *= dimSize;
// Dynamic size = s0 floorDiv productOfAllStaticSizes. s0 is bound to an
// operand on a line not visible here.
301 builder, expandShape.getLoc(), s0.
floorDiv(productOfAllStaticSizes),
305 return expandedSizes;
// Computes the strides of the dimensions produced by expanding reassociation
// group `groupId`. This is presumably the body of getExpandedStrides; its
// signature is not visible here.
//
// Pass 1 (innermost to outermost): each dimension's base stride is the
// running product of the sizes to its right in the group. The dynamic size,
// if any, is presumably excluded from that product.
//
// Pass 2: base strides are scaled by the source stride of dimension
// `groupId`. Dimensions left of the dynamic one additionally scale by the
// source size divided by the product of static sizes.
338 expandShape.getReassociationIndices()[groupId];
339 assert(!reassocGroup.empty() &&
340 "Reassociation group should have at least one dimension");
342 unsigned groupSize = reassocGroup.size();
343 MemRefType expandShapeType = expandShape.getResultType();
345 std::optional<int64_t> dynSizeIdx;
349 uint64_t currentStride = 1;
351 for (
int i = groupSize - 1; i >= 0; --i) {
352 expandedStrides[i] = builder.
getIndexAttr(currentStride);
// NOTE(review): same uint64_t vs signed dynamic-sentinel round-trip as in the
// size computation.
353 uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
354 if (ShapedType::isDynamic(dimSize)) {
355 assert(!dynSizeIdx &&
"There must be at most one dynamic size per group");
360 currentStride *= dimSize;
364 Value source = expandShape.getSrc();
365 auto sourceType = cast<MemRefType>(source.
getType());
366 auto [strides, offset] = sourceType.getStridesAndOffset();
// Source stride of dim `groupId`: taken from origStrides when dynamic in the
// type; the static branch is on a line not visible here.
368 OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
369 ? origStrides[groupId]
// After pass 1, currentStride holds the product of the group's static sizes.
// NOTE(review): this converts uint64_t to int64_t implicitly.
377 int64_t productOfAllStaticSizes = currentStride;
// A dynamic size in the group implies the source dimension is dynamic too.
378 assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
379 "We shouldn't be able to change dynamicity");
// Dims left of the dynamic one:
// stride = (origSize * base).floorDiv(productOfAllStaticSizes) * origStride.
384 for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
386 cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
389 builder, expandShape.getLoc(),
390 (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
391 {origSize, origStride});
// Remaining dims: stride = base * origStride.
397 for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
399 cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
402 builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
405 return expandedStrides;
// Part of a helper (presumably getProductOfValues) that multiplies together
// the entries selected by `indices`. For each index, the static entry of
// `maybeConstants` is used unless `isDynamic` reports it as dynamic. In that
// case a value operand is presumably used instead and counted in
// numberOfSymbols. The rest of the selection is on lines not visible here.
422 unsigned numberOfSymbols = 0;
423 unsigned groupSize =
indices.size();
424 for (
unsigned i = 0; i < groupSize; ++i) {
428 int64_t maybeConstant = maybeConstants[srcIdx];
430 inputValues.push_back(isDynamic(maybeConstant)
/// Returns the size of result dimension `groupId` of `collapseShape`:
/// - the static size from the result type, when known;
/// - otherwise, the product of the source sizes of the dimensions in
///   reassociation group `groupId`, computed via getProductOfValues.
454getCollapsedSize(memref::CollapseShapeOp collapseShape,
OpBuilder &builder,
458 MemRefType collapseShapeType = collapseShape.getResultType();
460 uint64_t size = collapseShapeType.getDimSize(groupId);
461 if (ShapedType::isStatic(size)) {
463 return collapsedSize;
// Dynamic result size: multiply the group's source sizes. Static ones come
// from the source shape; dynamic ones from origSizes.
469 Value source = collapseShape.getSrc();
470 auto sourceType = cast<MemRefType>(source.
getType());
473 collapseShape.getReassociationIndices()[groupId];
475 collapsedSize.push_back(getProductOfValues(
476 reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
477 origSizes, ShapedType::isDynamic));
479 return collapsedSize;
/// Returns the stride of collapsed result dimension `groupId`.
/// - Normally this is the stride of the last (innermost) source dimension in
///   the group whose size is not 1. It is taken from the source type when
///   static and from `origStrides` when dynamic.
/// - If every dimension in the group has size 1, the result type's stride is
///   presumably used when static. Otherwise the first dynamic source stride
///   in the group is returned.
495getCollapsedStride(memref::CollapseShapeOp collapseShape,
OpBuilder &builder,
499 collapseShape.getReassociationIndices()[groupId];
500 assert(!reassocGroup.empty() &&
501 "Reassociation group should have at least one dimension");
503 Value source = collapseShape.getSrc();
504 auto sourceType = cast<MemRefType>(source.
getType());
506 auto [strides, offset] = sourceType.getStridesAndOffset();
// Size-1 dimensions are skipped. Because every other dimension overwrites
// lastValidStride, the innermost one wins.
511 for (
int64_t currentDim : reassocGroup) {
517 if (srcShape[currentDim] == 1)
520 int64_t currentStride = strides[currentDim];
521 lastValidStride = ShapedType::isDynamic(currentStride)
522 ? origStrides[currentDim]
// No non-unit dimension in the group: fall back to the result type's stride.
525 if (!lastValidStride) {
528 MemRefType collapsedType = collapseShape.getResultType();
529 auto [collapsedStrides, collapsedOffset] =
530 collapsedType.getStridesAndOffset();
531 int64_t finalStride = collapsedStrides[groupId];
532 if (ShapedType::isDynamic(finalStride)) {
// Dynamic result stride: reuse the first dynamic source stride in the group.
535 for (
int64_t currentDim : reassocGroup) {
536 assert(srcShape[currentDim] == 1 &&
537 "We should be dealing with 1x1x...x1");
539 if (ShapedType::isDynamic(strides[currentDim]))
540 return {origStrides[currentDim]};
542 llvm_unreachable(
"We should have found a dynamic stride");
547 return {lastValidStride};
/// Shared driver for expand_shape / collapse_shape metadata resolution.
/// - Creates a memref.extract_strided_metadata on the reshape's source.
/// - Keeps the source offset unchanged (as a constant when static).
/// - Builds the result sizes and strides one reassociation group at a time,
///   through the `getReshapedSizes` / `getReshapedStrides` callbacks.
/// A rank-0 source returns early; the rest of that returned aggregate is on a
/// line not visible here.
560template <
typename ReassociativeReshapeLikeOp>
561static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
562 RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape,
571 getReshapedStrides) {
574 Location origLoc = reshape.getLoc();
575 Value source = reshape.getSrc();
576 auto sourceType = cast<MemRefType>(source.
getType());
577 unsigned sourceRank = sourceType.getRank();
579 auto newExtractStridedMetadata =
580 memref::ExtractStridedMetadataOp::create(rewriter, origLoc, source);
583 auto [strides, offset] = sourceType.getStridesAndOffset();
584 MemRefType reshapeType = reshape.getResultType();
585 unsigned reshapeRank = reshapeType.getRank();
// A reshape does not move the data, so the source offset is reused as is.
588 ShapedType::isDynamic(offset)
590 : rewriter.getIndexAttr(offset);
593 if (sourceRank == 0) {
595 return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
600 finalSizes.reserve(reshapeRank);
602 finalStrides.reserve(reshapeRank);
// One iteration per reassociation group. Each group contributes
// reshapedSizes.size() consecutive result dimensions.
609 unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
610 for (; idx != endIdx; ++idx) {
612 getReshapedSizes(reshape, rewriter, origSizes, idx);
614 reshape, rewriter, origSizes, origStrides, idx);
616 unsigned groupSize = reshapedSizes.size();
617 for (
unsigned i = 0; i < groupSize; ++i) {
618 finalSizes.push_back(reshapedSizes[i]);
619 finalStrides.push_back(reshapedStrides[i]);
// The number of groups equals the source rank for expand_shape and the
// result rank for collapse_shape.
622 assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
623 (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
624 "We should have visited all the input dimensions");
625 assert(finalSizes.size() == reshapeRank &&
626 "We should have populated all the values");
628 return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
629 finalSizes, finalStrides};
/// Pattern that rewrites a reassociative reshape op (instantiated for
/// memref.expand_shape and memref.collapse_shape in the pattern list).
/// It resolves the reshape's metadata with resolveReshapeStridedMetadata
/// using the size/stride callbacks supplied as template arguments. It then
/// replaces the reshape with an op built from (result type, basePtr, offset,
/// sizes, strides), presumably memref.reinterpret_cast; the created op's type
/// is on a line not visible here.
648template <
typename ReassociativeReshapeLikeOp,
660 LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
662 FailureOr<StridedMetadata> stridedMetadata =
663 resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp>(
664 rewriter, reshape, getReshapedSizes, getReshapedStrides);
665 if (
failed(stridedMetadata)) {
667 "failed to resolve reshape metadata");
671 reshape, reshape.getType(), stridedMetadata->basePtr,
672 stridedMetadata->offset, stridedMetadata->sizes,
673 stridedMetadata->strides);
/// Folds memref.extract_strided_metadata whose source is produced by a
/// memref.collapse_shape. The metadata is re-expressed in terms of the
/// collapse_shape's source, via getCollapsedSize / getCollapsedStride.
/// Replacement values are [basePtr, offset, sizes..., strides...].
691struct ExtractStridedMetadataOpCollapseShapeFolder
695 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
697 auto collapseShapeOp =
698 op.getSource().getDefiningOp<memref::CollapseShapeOp>();
699 if (!collapseShapeOp)
702 FailureOr<StridedMetadata> stridedMetadata =
703 resolveReshapeStridedMetadata<memref::CollapseShapeOp>(
704 rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride);
705 if (
failed(stridedMetadata)) {
708 "failed to resolve metadata in terms of source collapse_shape op");
// Location of the collapse_shape, presumably used to materialize the
// OpFoldResults as Values.
711 Location loc = collapseShapeOp.getLoc();
713 results.push_back(stridedMetadata->basePtr);
715 stridedMetadata->offset));
719 stridedMetadata->strides));
/// Folds memref.extract_strided_metadata whose source is produced by a
/// memref.expand_shape. The metadata is re-expressed in terms of the
/// expand_shape's source, via getExpandedSizes / getExpandedStrides.
/// Replacement values are [basePtr, offset, sizes..., strides...].
728struct ExtractStridedMetadataOpExpandShapeFolder
732 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
734 auto expandShapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
738 FailureOr<StridedMetadata> stridedMetadata =
739 resolveReshapeStridedMetadata<memref::ExpandShapeOp>(
740 rewriter, expandShapeOp, getExpandedSizes, getExpandedStrides);
741 if (
failed(stridedMetadata)) {
743 op,
"failed to resolve metadata in terms of source expand_shape op");
// Location of the expand_shape, presumably used to materialize the
// OpFoldResults as Values.
746 Location loc = expandShapeOp.getLoc();
748 results.push_back(stridedMetadata->basePtr);
750 stridedMetadata->offset));
754 stridedMetadata->strides));
/// Folds memref.extract_strided_metadata whose source is produced by an
/// alloc-like op (instantiated for memref.alloc and memref.alloca).
/// Only identity layouts are handled; other layouts are a match failure.
/// - Sizes come from the static shape, with dynamic dims filled in order from
///   the op's dynamic-size operands.
/// - Strides are derived from the sizes of the inner dimensions.
/// - The base buffer is null when unused, the allocation itself when its
///   type already matches, otherwise a memref.reinterpret_cast of it.
774template <
typename AllocLikeOp>
775struct ExtractStridedMetadataOpAllocFolder
780 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
782 auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>();
786 auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
787 if (!memRefType.getLayout().isIdentity())
789 allocLikeOp,
"alloc-like operations should have been normalized");
792 int rank = memRefType.getRank();
// Dynamic dimensions consume the dynamic-size operands in order.
795 ValueRange dynamic = allocLikeOp.getDynamicSizes();
798 unsigned dynamicPos = 0;
799 for (
int64_t size : memRefType.getShape()) {
800 if (ShapedType::isDynamic(size))
801 sizes.push_back(dynamic[dynamicPos++]);
// Strides from the second-innermost dimension outward. The stride of dim i
// uses the sizes of dims i+1 .. rank-1 (sizesInvolvedInStride); the innermost
// stride is set on a line not visible here.
809 unsigned symbolNumber = 0;
810 for (
int i = rank - 2; i >= 0; --i) {
812 assert(i + 1 + symbolNumber == sizes.size() &&
813 "The ArrayRef should encompass the last #symbolNumber sizes");
816 sizesInvolvedInStride);
// Results: [baseBuffer, offset, sizes..., strides...].
821 results.reserve(rank * 2 + 2);
823 auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
// An unused base buffer is replaced by a null value, so no cast is created.
825 if (op.getBaseBuffer().use_empty()) {
826 results.push_back(
nullptr);
828 if (allocLikeOp.getType() == baseBufferType)
829 results.push_back(allocLikeOp);
831 results.push_back(memref::ReinterpretCastOp::create(
832 rewriter, loc, baseBufferType, allocLikeOp, offset,
/// Folds memref.extract_strided_metadata whose source is produced by a
/// memref.get_global.
/// - Only identity layouts are handled; others are a match failure.
/// - The shape must be fully static (asserted).
/// - The base buffer is the get_global result when its type matches,
///   otherwise a memref.reinterpret_cast of it.
/// - Sizes and strides are then appended from the static shape and the
///   computed strides.
865struct ExtractStridedMetadataOpGetGlobalFolder
868 using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
871 PatternRewriter &rewriter)
const override {
872 auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>();
876 auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
877 if (!memRefType.getLayout().isIdentity()) {
880 "get-global operation result should have been normalized");
883 Location loc = op.getLoc();
884 int rank = memRefType.getRank();
887 ArrayRef<int64_t> sizes = memRefType.getShape();
888 assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
889 "unexpected dynamic shape for result of `memref.get_global` op");
// Results: [baseBuffer, offset, sizes..., strides...].
895 SmallVector<Value> results;
896 results.reserve(rank * 2 + 2);
898 auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
900 if (getGlobalOp.getType() == baseBufferType)
901 results.push_back(getGlobalOp);
903 results.push_back(memref::ReinterpretCastOp::create(
904 rewriter, loc, baseBufferType, getGlobalOp, offset,
906 ArrayRef<int64_t>()));
// Static sizes, then static strides, presumably materialized as index
// constants (loop bodies not visible here).
911 for (
auto size : sizes)
914 for (
auto stride : strides)
/// Folds memref.extract_strided_metadata whose source is produced by a
/// memref.assume_alignment. The op is presumably replaced by a new
/// extract_strided_metadata on the assume_alignment's view source.
933struct ExtractStridedMetadataOpAssumeAlignmentFolder
936 using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
938 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
939 PatternRewriter &rewriter)
const override {
940 auto assumeAlignmentOp =
941 op.getSource().getDefiningOp<memref::AssumeAlignmentOp>();
942 if (!assumeAlignmentOp)
946 op, assumeAlignmentOp.getViewSource());
/// Retargets memref.extract_aligned_pointer_as_index in place when its
/// operand is the view result of a memref.subview or memref.reinterpret_cast.
/// The operand is swapped for the view op's source. This presumably relies on
/// those view ops sharing their source's aligned pointer.
953class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
958 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
959 PatternRewriter &rewriter)
const override {
961 extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
// Restricted to these two view ops, and only when the operand is the view
// result itself.
965 if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest() ||
966 !isa<memref::SubViewOp, memref::ReinterpretCastOp>(viewLikeOp))
969 extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
/// Folds memref.extract_strided_metadata whose source is produced by a
/// memref.reinterpret_cast.
/// - The base buffer comes from a new extract_strided_metadata on the cast's
///   source.
/// - Offset, sizes and strides are taken directly from the cast's mixed
///   operands.
/// The fold is refused when inferReturnTypes fails for the cast's source.
988class ExtractStridedMetadataOpReinterpretCastFolder
993 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
994 PatternRewriter &rewriter)
const override {
995 auto reinterpretCastOp = extractStridedMetadataOp.getSource()
996 .getDefiningOp<memref::ReinterpretCastOp>();
997 if (!reinterpretCastOp)
1000 Location loc = extractStridedMetadataOp.getLoc();
// Bail out if an extract_strided_metadata cannot be typed on the cast's source.
1002 SmallVector<Type> inferredReturnTypes;
1003 if (
failed(extractStridedMetadataOp.inferReturnTypes(
1004 rewriter.
getContext(), loc, {reinterpretCastOp.getSource()},
1006 inferredReturnTypes)))
1008 reinterpretCastOp,
"reinterpret_cast source's type is incompatible");
1010 auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
1011 unsigned rank = memrefType.getRank();
// Layout: [0] base buffer, [1] offset, [2, 2+rank) sizes, [2+rank, 2+2*rank)
// strides.
1012 SmallVector<OpFoldResult> results;
1013 results.resize_for_overwrite(rank * 2 + 2);
1015 auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
1016 rewriter, loc, reinterpretCastOp.getSource());
1019 results[0] = newExtractStridedMetadata.getBaseBuffer();
1023 rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]);
1025 const unsigned sizeStartIdx = 2;
1026 const unsigned strideStartIdx = sizeStartIdx + rank;
1028 SmallVector<OpFoldResult> sizes = reinterpretCastOp.getMixedSizes();
1029 SmallVector<OpFoldResult> strides = reinterpretCastOp.getMixedStrides();
1030 for (
unsigned i = 0; i < rank; ++i) {
1031 results[sizeStartIdx + i] = sizes[i];
1032 results[strideStartIdx + i] = strides[i];
1034 rewriter.
replaceOp(extractStridedMetadataOp,
/// Folds memref.extract_strided_metadata whose source is produced by a
/// memref.memory_space_cast.
/// - All results come from a new extract_strided_metadata on the cast's
///   source.
/// - If the original base buffer is used, the new base buffer is cast to the
///   memory space of the original cast's result.
/// - If it is unused, it is replaced by a null value.
1050class ExtractStridedMetadataOpMemorySpaceCastFolder
1055 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1056 PatternRewriter &rewriter)
const override {
1057 Location loc = extractStridedMetadataOp.getLoc();
1058 Value source = extractStridedMetadataOp.getSource();
1059 auto memSpaceCastOp = source.
getDefiningOp<memref::MemorySpaceCastOp>();
1060 if (!memSpaceCastOp)
1062 auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
1063 rewriter, loc, memSpaceCastOp.getSource());
1064 SmallVector<Value> results(newExtractStridedMetadata.getResults());
// Only the base buffer carries a memory space; offset/sizes/strides are
// reused unchanged.
1071 if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) {
1072 auto baseBuffer = results[0];
1073 auto baseBufferType = cast<MemRefType>(baseBuffer.getType());
1074 MemRefType::Builder newTypeBuilder(baseBufferType);
1075 newTypeBuilder.setMemorySpace(
1076 memSpaceCastOp.getResult().getType().getMemorySpace());
1077 results[0] = memref::MemorySpaceCastOp::create(
1078 rewriter, loc, Type{newTypeBuilder}, baseBuffer);
1080 results[0] =
nullptr;
1082 rewriter.
replaceOp(extractStridedMetadataOp, results);
/// Folds memref.extract_strided_metadata whose source is itself a result of
/// another memref.extract_strided_metadata (presumably its base buffer, the
/// only memref-typed result). The replacement starts with that source op's
/// base buffer; the remaining replacement values are on lines not visible
/// here.
1094class ExtractStridedMetadataOpExtractStridedMetadataFolder
1099 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1100 PatternRewriter &rewriter)
const override {
1101 auto sourceExtractStridedMetadataOp =
1102 extractStridedMetadataOp.getSource()
1103 .getDefiningOp<memref::ExtractStridedMetadataOp>();
1104 if (!sourceExtractStridedMetadataOp)
1106 Location loc = extractStridedMetadataOp.getLoc();
1107 rewriter.
replaceOp(extractStridedMetadataOp,
1108 {sourceExtractStridedMetadataOp.getBaseBuffer(),
// Pattern list presumably registered by populateExpandStridedMetadataPatterns
// (its header is not visible here). Besides the extract_strided_metadata
// folders, it includes the ReshapeFolder instantiations, which rewrite
// expand_shape / collapse_shape ops themselves.
1119 ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes,
1120 getExpandedStrides>,
1121 ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize,
1122 getCollapsedStride>,
1123 ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
1124 ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1125 ExtractStridedMetadataOpCollapseShapeFolder,
1126 ExtractStridedMetadataOpExpandShapeFolder,
1127 ExtractStridedMetadataOpGetGlobalFolder,
1128 RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
1129 ExtractStridedMetadataOpReinterpretCastFolder,
1130 ExtractStridedMetadataOpSubviewFolder,
1131 ExtractStridedMetadataOpMemorySpaceCastFolder,
1132 ExtractStridedMetadataOpAssumeAlignmentFolder,
1133 ExtractStridedMetadataOpExtractStridedMetadataFolder>(
// Pattern list presumably registered by
// populateResolveExtractStridedMetadataPatterns (header not visible here).
// Unlike the list above, it contains no ReshapeFolder instantiations. Only
// patterns rooted on extract_strided_metadata /
// extract_aligned_pointer_as_index are added.
1139 patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
1140 ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1141 ExtractStridedMetadataOpCollapseShapeFolder,
1142 ExtractStridedMetadataOpExpandShapeFolder,
1143 ExtractStridedMetadataOpGetGlobalFolder,
1144 ExtractStridedMetadataOpSubviewFolder,
1145 RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
1146 ExtractStridedMetadataOpReinterpretCastFolder,
1147 ExtractStridedMetadataOpMemorySpaceCastFolder,
1148 ExtractStridedMetadataOpAssumeAlignmentFolder,
1149 ExtractStridedMetadataOpExtractStridedMetadataFolder>(
/// Pass wrapper for the patterns in this file. Its base class is presumably
/// generated from the GEN_PASS_DEF_EXPANDSTRIDEDMETADATAPASS definition in the
/// includes.
1159struct ExpandStridedMetadataPass final
1161 ExpandStridedMetadataPass> {
1162 void runOnOperation()
override;
// Out-of-line definition of the pass entry point; its body is not visible
// here.
1167void ExpandStridedMetadataPass::runOnOperation() {
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
AffineExpr getAffineConstantExpr(int64_t constant)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for resolving memref.extract_strided_metadata into memref.extract_strided_metadata o...
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for expanding memref operations that modify the metadata (sizes, offset,...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
SmallVector< int64_t > computeSuffixProduct(ArrayRef< int64_t > sizes)
Given a set of sizes, return the suffix product.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::function_ref< Fn > function_ref
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.