25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallBitVector.h"
31#define GEN_PASS_DEF_EXPANDSTRIDEDMETADATAPASS
32#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
41struct StridedMetadata {
44 SmallVector<OpFoldResult> sizes;
45 SmallVector<OpFoldResult> strides;
59static FailureOr<StridedMetadata>
61 memref::SubViewOp subview) {
64 Value source = subview.getSource();
65 auto sourceType = cast<MemRefType>(source.
getType());
66 unsigned sourceRank = sourceType.getRank();
68 auto newExtractStridedMetadata =
69 memref::ExtractStridedMetadataOp::create(rewriter, origLoc, source);
71 auto [sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
73 auto [resultStrides, resultOffset] = subview.getType().getStridesAndOffset();
81 auto origStrides = newExtractStridedMetadata.getStrides();
89 values[0] = ShapedType::isDynamic(sourceOffset)
91 : rewriter.getIndexAttr(sourceOffset);
96 for (
unsigned i = 0; i < sourceRank; ++i) {
99 ShapedType::isDynamic(sourceStrides[i])
103 rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));
106 unsigned baseIdxForDim = 1 + 2 * i;
107 unsigned subOffsetForDim = baseIdxForDim;
108 unsigned origStrideForDim = baseIdxForDim + 1;
109 expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
110 values[subOffsetForDim] = subOffsets[i];
111 values[origStrideForDim] = origStride;
121 if (computedOffset && ShapedType::isStatic(resultOffset))
122 assert(*computedOffset == resultOffset &&
123 "mismatch between computed offset and result type offset");
129 auto subType = cast<MemRefType>(subview.getType());
130 unsigned subRank = subType.getRank();
139 llvm::SmallBitVector droppedDims = subview.getDroppedDims();
142 finalSizes.reserve(subRank);
145 finalStrides.reserve(subRank);
151 for (
unsigned i = 0; i < sourceRank; ++i) {
152 if (droppedDims.test(i))
155 finalSizes.push_back(subSizes[i]);
156 finalStrides.push_back(strides[i]);
161 if (computedStride && ShapedType::isStatic(resultStrides[
j]))
162 assert(*computedStride == resultStrides[
j] &&
163 "mismatch between computed stride and result type stride");
167 assert(finalSizes.size() == subRank &&
168 "Should have populated all the values at this point");
169 return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset,
170 finalSizes, finalStrides};
189 using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
191 LogicalResult matchAndRewrite(memref::SubViewOp subview,
192 PatternRewriter &rewriter)
const override {
193 FailureOr<StridedMetadata> stridedMetadata =
194 resolveSubviewStridedMetadata(rewriter, subview);
195 if (
failed(stridedMetadata)) {
197 "failed to resolve subview metadata");
201 subview, subview.getType(), stridedMetadata->basePtr,
202 stridedMetadata->offset, stridedMetadata->sizes,
203 stridedMetadata->strides);
221struct ExtractStridedMetadataOpSubviewFolder
227 auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>();
231 FailureOr<StridedMetadata> stridedMetadata =
232 resolveSubviewStridedMetadata(rewriter, subviewOp);
233 if (failed(stridedMetadata)) {
235 op,
"failed to resolve metadata in terms of source subview op");
239 results.reserve(subviewOp.getType().getRank() * 2 + 2);
240 results.push_back(stridedMetadata->basePtr);
242 stridedMetadata->offset));
246 stridedMetadata->strides));
269getExpandedSizes(memref::ExpandShapeOp expandShape,
OpBuilder &builder,
272 expandShape.getReassociationIndices()[groupId];
273 assert(!reassocGroup.empty() &&
274 "Reassociation group should have at least one dimension");
276 unsigned groupSize = reassocGroup.size();
279 uint64_t productOfAllStaticSizes = 1;
280 std::optional<unsigned> dynSizeIdx;
281 MemRefType expandShapeType = expandShape.getResultType();
284 for (
unsigned i = 0; i < groupSize; ++i) {
285 uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
286 if (ShapedType::isDynamic(dimSize)) {
287 assert(!dynSizeIdx &&
"There must be at most one dynamic size per group");
291 productOfAllStaticSizes *= dimSize;
301 builder, expandShape.getLoc(), s0.
floorDiv(productOfAllStaticSizes),
305 return expandedSizes;
338 expandShape.getReassociationIndices()[groupId];
339 assert(!reassocGroup.empty() &&
340 "Reassociation group should have at least one dimension");
342 unsigned groupSize = reassocGroup.size();
343 MemRefType expandShapeType = expandShape.getResultType();
345 std::optional<int64_t> dynSizeIdx;
349 uint64_t currentStride = 1;
351 for (
int i = groupSize - 1; i >= 0; --i) {
352 expandedStrides[i] = builder.
getIndexAttr(currentStride);
353 uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
354 if (ShapedType::isDynamic(dimSize)) {
355 assert(!dynSizeIdx &&
"There must be at most one dynamic size per group");
360 currentStride *= dimSize;
364 Value source = expandShape.getSrc();
365 auto sourceType = cast<MemRefType>(source.
getType());
366 auto [strides, offset] = sourceType.getStridesAndOffset();
368 OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
369 ? origStrides[groupId]
377 int64_t productOfAllStaticSizes = currentStride;
378 assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
379 "We shouldn't be able to change dynamicity");
384 for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
386 cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
389 builder, expandShape.getLoc(),
390 (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
391 {origSize, origStride});
397 for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
399 cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
402 builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
405 return expandedStrides;
422 unsigned numberOfSymbols = 0;
423 unsigned groupSize =
indices.size();
424 for (
unsigned i = 0; i < groupSize; ++i) {
428 int64_t maybeConstant = maybeConstants[srcIdx];
430 inputValues.push_back(isDynamic(maybeConstant)
454getCollapsedSize(memref::CollapseShapeOp collapseShape,
OpBuilder &builder,
458 MemRefType collapseShapeType = collapseShape.getResultType();
460 uint64_t size = collapseShapeType.getDimSize(groupId);
461 if (ShapedType::isStatic(size)) {
463 return collapsedSize;
469 Value source = collapseShape.getSrc();
470 auto sourceType = cast<MemRefType>(source.
getType());
473 collapseShape.getReassociationIndices()[groupId];
475 collapsedSize.push_back(getProductOfValues(
476 reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
477 origSizes, ShapedType::isDynamic));
479 return collapsedSize;
495getCollapsedStride(memref::CollapseShapeOp collapseShape,
OpBuilder &builder,
499 collapseShape.getReassociationIndices()[groupId];
500 assert(!reassocGroup.empty() &&
501 "Reassociation group should have at least one dimension");
503 Value source = collapseShape.getSrc();
504 auto sourceType = cast<MemRefType>(source.
getType());
506 auto [strides, offset] = sourceType.getStridesAndOffset();
511 for (
int64_t currentDim : reassocGroup) {
517 if (srcShape[currentDim] == 1)
520 int64_t currentStride = strides[currentDim];
521 lastValidStride = ShapedType::isDynamic(currentStride)
522 ? origStrides[currentDim]
525 if (!lastValidStride) {
528 MemRefType collapsedType = collapseShape.getResultType();
529 auto [collapsedStrides, collapsedOffset] =
530 collapsedType.getStridesAndOffset();
531 int64_t finalStride = collapsedStrides[groupId];
532 if (ShapedType::isDynamic(finalStride)) {
535 for (
int64_t currentDim : reassocGroup) {
536 assert(srcShape[currentDim] == 1 &&
537 "We should be dealing with 1x1x...x1");
539 if (ShapedType::isDynamic(strides[currentDim]))
540 return {origStrides[currentDim]};
542 llvm_unreachable(
"We should have found a dynamic stride");
547 return {lastValidStride};
560template <
typename ReassociativeReshapeLikeOp>
561static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
562 RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape,
571 getReshapedStrides) {
574 Location origLoc = reshape.getLoc();
575 Value source = reshape.getSrc();
576 auto sourceType = cast<MemRefType>(source.
getType());
577 unsigned sourceRank = sourceType.getRank();
579 auto newExtractStridedMetadata =
580 memref::ExtractStridedMetadataOp::create(rewriter, origLoc, source);
583 auto [strides, offset] = sourceType.getStridesAndOffset();
584 MemRefType reshapeType = reshape.getResultType();
585 unsigned reshapeRank = reshapeType.getRank();
588 ShapedType::isDynamic(offset)
590 : rewriter.getIndexAttr(offset);
593 if (sourceRank == 0) {
595 return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
600 finalSizes.reserve(reshapeRank);
602 finalStrides.reserve(reshapeRank);
609 unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
610 for (; idx != endIdx; ++idx) {
612 getReshapedSizes(reshape, rewriter, origSizes, idx);
614 reshape, rewriter, origSizes, origStrides, idx);
616 unsigned groupSize = reshapedSizes.size();
617 for (
unsigned i = 0; i < groupSize; ++i) {
618 finalSizes.push_back(reshapedSizes[i]);
619 finalStrides.push_back(reshapedStrides[i]);
622 assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
623 (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
624 "We should have visited all the input dimensions");
625 assert(finalSizes.size() == reshapeRank &&
626 "We should have populated all the values");
628 return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
629 finalSizes, finalStrides};
648template <
typename ReassociativeReshapeLikeOp,
660 LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
662 FailureOr<StridedMetadata> stridedMetadata =
663 resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp>(
664 rewriter, reshape, getReshapedSizes, getReshapedStrides);
665 if (
failed(stridedMetadata)) {
667 "failed to resolve reshape metadata");
671 reshape, reshape.getType(), stridedMetadata->basePtr,
672 stridedMetadata->offset, stridedMetadata->sizes,
673 stridedMetadata->strides);
691struct ExtractStridedMetadataOpCollapseShapeFolder
695 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
697 auto collapseShapeOp =
698 op.getSource().getDefiningOp<memref::CollapseShapeOp>();
699 if (!collapseShapeOp)
702 FailureOr<StridedMetadata> stridedMetadata =
703 resolveReshapeStridedMetadata<memref::CollapseShapeOp>(
704 rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride);
705 if (
failed(stridedMetadata)) {
708 "failed to resolve metadata in terms of source collapse_shape op");
711 Location loc = collapseShapeOp.getLoc();
713 results.push_back(stridedMetadata->basePtr);
715 stridedMetadata->offset));
719 stridedMetadata->strides));
728struct ExtractStridedMetadataOpExpandShapeFolder
732 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
734 auto expandShapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
738 FailureOr<StridedMetadata> stridedMetadata =
739 resolveReshapeStridedMetadata<memref::ExpandShapeOp>(
740 rewriter, expandShapeOp, getExpandedSizes, getExpandedStrides);
741 if (
failed(stridedMetadata)) {
743 op,
"failed to resolve metadata in terms of source expand_shape op");
746 Location loc = expandShapeOp.getLoc();
748 results.push_back(stridedMetadata->basePtr);
750 stridedMetadata->offset));
754 stridedMetadata->strides));
774template <
typename AllocLikeOp>
775struct ExtractStridedMetadataOpAllocFolder
780 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
782 auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>();
786 auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
787 if (!memRefType.getLayout().isIdentity())
789 allocLikeOp,
"alloc-like operations should have been normalized");
792 int rank = memRefType.getRank();
795 ValueRange dynamic = allocLikeOp.getDynamicSizes();
798 unsigned dynamicPos = 0;
799 for (
int64_t size : memRefType.getShape()) {
800 if (ShapedType::isDynamic(size))
801 sizes.push_back(dynamic[dynamicPos++]);
809 unsigned symbolNumber = 0;
810 for (
int i = rank - 2; i >= 0; --i) {
812 assert(i + 1 + symbolNumber == sizes.size() &&
813 "The ArrayRef should encompass the last #symbolNumber sizes");
816 sizesInvolvedInStride);
821 results.reserve(rank * 2 + 2);
823 auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
825 if (op.getBaseBuffer().use_empty()) {
826 results.push_back(
nullptr);
828 if (allocLikeOp.getType() == baseBufferType)
829 results.push_back(allocLikeOp);
831 results.push_back(memref::ReinterpretCastOp::create(
832 rewriter, loc, baseBufferType, allocLikeOp, offset,
865struct ExtractStridedMetadataOpGetGlobalFolder
868 using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
871 PatternRewriter &rewriter)
const override {
872 auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>();
876 auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
877 if (!memRefType.getLayout().isIdentity()) {
880 "get-global operation result should have been normalized");
883 Location loc = op.getLoc();
884 int rank = memRefType.getRank();
887 ArrayRef<int64_t> sizes = memRefType.getShape();
888 assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
889 "unexpected dynamic shape for result of `memref.get_global` op");
895 SmallVector<Value> results;
896 results.reserve(rank * 2 + 2);
898 auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
900 if (getGlobalOp.getType() == baseBufferType)
901 results.push_back(getGlobalOp);
903 results.push_back(memref::ReinterpretCastOp::create(
904 rewriter, loc, baseBufferType, getGlobalOp, offset,
906 ArrayRef<int64_t>()));
911 for (
auto size : sizes)
914 for (
auto stride : strides)
933struct ExtractStridedMetadataOpAssumeAlignmentFolder
936 using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
938 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
939 PatternRewriter &rewriter)
const override {
940 auto assumeAlignmentOp =
941 op.getSource().getDefiningOp<memref::AssumeAlignmentOp>();
942 if (!assumeAlignmentOp)
946 op, assumeAlignmentOp.getViewSource());
953class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
958 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
959 PatternRewriter &rewriter)
const override {
961 extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
962 if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest())
965 extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
984class ExtractStridedMetadataOpReinterpretCastFolder
989 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
990 PatternRewriter &rewriter)
const override {
991 auto reinterpretCastOp = extractStridedMetadataOp.getSource()
992 .getDefiningOp<memref::ReinterpretCastOp>();
993 if (!reinterpretCastOp)
996 Location loc = extractStridedMetadataOp.getLoc();
998 SmallVector<Type> inferredReturnTypes;
999 if (
failed(extractStridedMetadataOp.inferReturnTypes(
1000 rewriter.
getContext(), loc, {reinterpretCastOp.getSource()},
1002 inferredReturnTypes)))
1004 reinterpretCastOp,
"reinterpret_cast source's type is incompatible");
1006 auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
1007 unsigned rank = memrefType.getRank();
1008 SmallVector<OpFoldResult> results;
1009 results.resize_for_overwrite(rank * 2 + 2);
1011 auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
1012 rewriter, loc, reinterpretCastOp.getSource());
1015 results[0] = newExtractStridedMetadata.getBaseBuffer();
1019 rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]);
1021 const unsigned sizeStartIdx = 2;
1022 const unsigned strideStartIdx = sizeStartIdx + rank;
1024 SmallVector<OpFoldResult> sizes = reinterpretCastOp.getMixedSizes();
1025 SmallVector<OpFoldResult> strides = reinterpretCastOp.getMixedStrides();
1026 for (
unsigned i = 0; i < rank; ++i) {
1027 results[sizeStartIdx + i] = sizes[i];
1028 results[strideStartIdx + i] = strides[i];
1030 rewriter.
replaceOp(extractStridedMetadataOp,
1046class ExtractStridedMetadataOpMemorySpaceCastFolder
1051 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1052 PatternRewriter &rewriter)
const override {
1053 Location loc = extractStridedMetadataOp.getLoc();
1054 Value source = extractStridedMetadataOp.getSource();
1055 auto memSpaceCastOp = source.
getDefiningOp<memref::MemorySpaceCastOp>();
1056 if (!memSpaceCastOp)
1058 auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
1059 rewriter, loc, memSpaceCastOp.getSource());
1060 SmallVector<Value> results(newExtractStridedMetadata.getResults());
1067 if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) {
1068 auto baseBuffer = results[0];
1069 auto baseBufferType = cast<MemRefType>(baseBuffer.getType());
1070 MemRefType::Builder newTypeBuilder(baseBufferType);
1071 newTypeBuilder.setMemorySpace(
1072 memSpaceCastOp.getResult().getType().getMemorySpace());
1073 results[0] = memref::MemorySpaceCastOp::create(
1074 rewriter, loc, Type{newTypeBuilder}, baseBuffer);
1076 results[0] =
nullptr;
1078 rewriter.
replaceOp(extractStridedMetadataOp, results);
1090class ExtractStridedMetadataOpExtractStridedMetadataFolder
1095 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1096 PatternRewriter &rewriter)
const override {
1097 auto sourceExtractStridedMetadataOp =
1098 extractStridedMetadataOp.getSource()
1099 .getDefiningOp<memref::ExtractStridedMetadataOp>();
1100 if (!sourceExtractStridedMetadataOp)
1102 Location loc = extractStridedMetadataOp.getLoc();
1103 rewriter.
replaceOp(extractStridedMetadataOp,
1104 {sourceExtractStridedMetadataOp.getBaseBuffer(),
1115 ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes,
1116 getExpandedStrides>,
1117 ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize,
1118 getCollapsedStride>,
1119 ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
1120 ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1121 ExtractStridedMetadataOpCollapseShapeFolder,
1122 ExtractStridedMetadataOpExpandShapeFolder,
1123 ExtractStridedMetadataOpGetGlobalFolder,
1124 RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
1125 ExtractStridedMetadataOpReinterpretCastFolder,
1126 ExtractStridedMetadataOpSubviewFolder,
1127 ExtractStridedMetadataOpMemorySpaceCastFolder,
1128 ExtractStridedMetadataOpAssumeAlignmentFolder,
1129 ExtractStridedMetadataOpExtractStridedMetadataFolder>(
1135 patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
1136 ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1137 ExtractStridedMetadataOpCollapseShapeFolder,
1138 ExtractStridedMetadataOpExpandShapeFolder,
1139 ExtractStridedMetadataOpGetGlobalFolder,
1140 ExtractStridedMetadataOpSubviewFolder,
1141 RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
1142 ExtractStridedMetadataOpReinterpretCastFolder,
1143 ExtractStridedMetadataOpMemorySpaceCastFolder,
1144 ExtractStridedMetadataOpAssumeAlignmentFolder,
1145 ExtractStridedMetadataOpExtractStridedMetadataFolder>(
1155struct ExpandStridedMetadataPass final
1157 ExpandStridedMetadataPass> {
1158 void runOnOperation()
override;
1163void ExpandStridedMetadataPass::runOnOperation() {
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
AffineExpr getAffineConstantExpr(int64_t constant)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for resolving memref.extract_strided_metadata into memref.extract_strided_metadata o...
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for expanding memref operations that modify the metadata (sizes, offset,...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
SmallVector< int64_t > computeSuffixProduct(ArrayRef< int64_t > sizes)
Given a set of sizes, return the suffix product.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::function_ref< Fn > function_ref
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.