25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallBitVector.h"
31 #define GEN_PASS_DEF_EXPANDSTRIDEDMETADATAPASS
32 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
41 struct StridedMetadata {
59 static FailureOr<StridedMetadata>
61 memref::SubViewOp subview) {
64 Value source = subview.getSource();
65 auto sourceType = cast<MemRefType>(source.
getType());
66 unsigned sourceRank = sourceType.getRank();
68 auto newExtractStridedMetadata =
69 rewriter.
create<memref::ExtractStridedMetadataOp>(origLoc, source);
71 auto [sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
73 auto [resultStrides, resultOffset] = subview.getType().getStridesAndOffset();
81 auto origStrides = newExtractStridedMetadata.getStrides();
89 values[0] = ShapedType::isDynamic(sourceOffset)
91 : rewriter.getIndexAttr(sourceOffset);
96 for (
unsigned i = 0; i < sourceRank; ++i) {
99 ShapedType::isDynamic(sourceStrides[i])
103 rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));
106 unsigned baseIdxForDim = 1 + 2 * i;
107 unsigned subOffsetForDim = baseIdxForDim;
108 unsigned origStrideForDim = baseIdxForDim + 1;
109 expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
110 values[subOffsetForDim] = subOffsets[i];
111 values[origStrideForDim] = origStride;
121 if (computedOffset && !ShapedType::isDynamic(resultOffset))
122 assert(*computedOffset == resultOffset &&
123 "mismatch between computed offset and result type offset");
129 auto subType = cast<MemRefType>(subview.getType());
130 unsigned subRank = subType.getRank();
139 llvm::SmallBitVector droppedDims = subview.getDroppedDims();
142 finalSizes.reserve(subRank);
145 finalStrides.reserve(subRank);
151 for (
unsigned i = 0; i < sourceRank; ++i) {
152 if (droppedDims.test(i))
155 finalSizes.push_back(subSizes[i]);
156 finalStrides.push_back(strides[i]);
161 if (computedStride && !ShapedType::isDynamic(resultStrides[
j]))
162 assert(*computedStride == resultStrides[
j] &&
163 "mismatch between computed stride and result type stride");
167 assert(finalSizes.size() == subRank &&
168 "Should have populated all the values at this point");
169 return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset,
170 finalSizes, finalStrides};
191 LogicalResult matchAndRewrite(memref::SubViewOp subview,
193 FailureOr<StridedMetadata> stridedMetadata =
194 resolveSubviewStridedMetadata(rewriter, subview);
195 if (failed(stridedMetadata)) {
197 "failed to resolve subview metadata");
201 subview, subview.getType(), stridedMetadata->basePtr,
202 stridedMetadata->offset, stridedMetadata->sizes,
203 stridedMetadata->strides);
221 struct ExtractStridedMetadataOpSubviewFolder
225 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
227 auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>();
231 FailureOr<StridedMetadata> stridedMetadata =
232 resolveSubviewStridedMetadata(rewriter, subviewOp);
233 if (failed(stridedMetadata)) {
235 op,
"failed to resolve metadata in terms of source subview op");
239 results.reserve(subviewOp.getType().getRank() * 2 + 2);
240 results.push_back(stridedMetadata->basePtr);
242 stridedMetadata->offset));
246 stridedMetadata->strides));
269 getExpandedSizes(memref::ExpandShapeOp expandShape,
OpBuilder &builder,
272 expandShape.getReassociationIndices()[groupId];
273 assert(!reassocGroup.empty() &&
274 "Reassociation group should have at least one dimension");
276 unsigned groupSize = reassocGroup.size();
279 uint64_t productOfAllStaticSizes = 1;
280 std::optional<unsigned> dynSizeIdx;
281 MemRefType expandShapeType = expandShape.getResultType();
284 for (
unsigned i = 0; i < groupSize; ++i) {
285 uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
286 if (ShapedType::isDynamic(dimSize)) {
287 assert(!dynSizeIdx &&
"There must be at most one dynamic size per group");
291 productOfAllStaticSizes *= dimSize;
301 builder, expandShape.getLoc(), s0.
floorDiv(productOfAllStaticSizes),
305 return expandedSizes;
338 expandShape.getReassociationIndices()[groupId];
339 assert(!reassocGroup.empty() &&
340 "Reassociation group should have at least one dimension");
342 unsigned groupSize = reassocGroup.size();
343 MemRefType expandShapeType = expandShape.getResultType();
345 std::optional<int64_t> dynSizeIdx;
349 uint64_t currentStride = 1;
351 for (
int i = groupSize - 1; i >= 0; --i) {
352 expandedStrides[i] = builder.
getIndexAttr(currentStride);
353 uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
354 if (ShapedType::isDynamic(dimSize)) {
355 assert(!dynSizeIdx &&
"There must be at most one dynamic size per group");
360 currentStride *= dimSize;
364 Value source = expandShape.getSrc();
365 auto sourceType = cast<MemRefType>(source.
getType());
366 auto [strides, offset] = sourceType.getStridesAndOffset();
368 OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
369 ? origStrides[groupId]
373 int64_t doneStrideIdx = 0;
377 int64_t productOfAllStaticSizes = currentStride;
378 assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
379 "We shouldn't be able to change dynamicity");
384 for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
385 int64_t baseExpandedStride =
386 cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
389 builder, expandShape.getLoc(),
390 (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
391 {origSize, origStride});
397 for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
398 int64_t baseExpandedStride =
399 cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
402 builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
405 return expandedStrides;
422 unsigned numberOfSymbols = 0;
423 unsigned groupSize = indices.size();
424 for (
unsigned i = 0; i < groupSize; ++i) {
427 unsigned srcIdx = indices[i];
428 int64_t maybeConstant = maybeConstants[srcIdx];
430 inputValues.push_back(isDynamic(maybeConstant)
454 getCollapsedSize(memref::CollapseShapeOp collapseShape,
OpBuilder &builder,
458 MemRefType collapseShapeType = collapseShape.getResultType();
460 uint64_t size = collapseShapeType.getDimSize(groupId);
461 if (!ShapedType::isDynamic(size)) {
463 return collapsedSize;
469 Value source = collapseShape.getSrc();
470 auto sourceType = cast<MemRefType>(source.
getType());
473 collapseShape.getReassociationIndices()[groupId];
475 collapsedSize.push_back(getProductOfValues(
476 reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
477 origSizes, ShapedType::isDynamic));
479 return collapsedSize;
495 getCollapsedStride(memref::CollapseShapeOp collapseShape,
OpBuilder &builder,
499 collapseShape.getReassociationIndices()[groupId];
500 assert(!reassocGroup.empty() &&
501 "Reassociation group should have at least one dimension");
503 Value source = collapseShape.getSrc();
504 auto sourceType = cast<MemRefType>(source.
getType());
506 auto [strides, offset] = sourceType.getStridesAndOffset();
511 for (int64_t currentDim : reassocGroup) {
517 if (srcShape[currentDim] == 1)
520 int64_t currentStride = strides[currentDim];
521 lastValidStride = ShapedType::isDynamic(currentStride)
522 ? origStrides[currentDim]
525 if (!lastValidStride) {
528 MemRefType collapsedType = collapseShape.getResultType();
529 auto [collapsedStrides, collapsedOffset] =
530 collapsedType.getStridesAndOffset();
531 int64_t finalStride = collapsedStrides[groupId];
532 if (ShapedType::isDynamic(finalStride)) {
535 for (int64_t currentDim : reassocGroup) {
536 assert(srcShape[currentDim] == 1 &&
537 "We should be dealing with 1x1x...x1");
539 if (ShapedType::isDynamic(strides[currentDim]))
540 return {origStrides[currentDim]};
542 llvm_unreachable(
"We should have found a dynamic stride");
547 return {lastValidStride};
560 template <
typename ReassociativeReshapeLikeOp>
561 static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
562 RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape,
571 getReshapedStrides) {
574 Location origLoc = reshape.getLoc();
575 Value source = reshape.getSrc();
576 auto sourceType = cast<MemRefType>(source.
getType());
577 unsigned sourceRank = sourceType.getRank();
579 auto newExtractStridedMetadata =
580 rewriter.
create<memref::ExtractStridedMetadataOp>(origLoc, source);
583 auto [strides, offset] = sourceType.getStridesAndOffset();
584 MemRefType reshapeType = reshape.getResultType();
585 unsigned reshapeRank = reshapeType.getRank();
588 ShapedType::isDynamic(offset)
590 : rewriter.getIndexAttr(offset);
593 if (sourceRank == 0) {
595 return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
600 finalSizes.reserve(reshapeRank);
602 finalStrides.reserve(reshapeRank);
609 unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
610 for (; idx != endIdx; ++idx) {
612 getReshapedSizes(reshape, rewriter, origSizes, idx);
614 reshape, rewriter, origSizes, origStrides, idx);
616 unsigned groupSize = reshapedSizes.size();
617 for (
unsigned i = 0; i < groupSize; ++i) {
618 finalSizes.push_back(reshapedSizes[i]);
619 finalStrides.push_back(reshapedStrides[i]);
622 assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
623 (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
624 "We should have visited all the input dimensions");
625 assert(finalSizes.size() == reshapeRank &&
626 "We should have populated all the values");
628 return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
629 finalSizes, finalStrides};
648 template <
typename ReassociativeReshapeLikeOp,
660 LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
662 FailureOr<StridedMetadata> stridedMetadata =
663 resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp>(
664 rewriter, reshape, getReshapedSizes, getReshapedStrides);
665 if (failed(stridedMetadata)) {
667 "failed to resolve reshape metadata");
671 reshape, reshape.getType(), stridedMetadata->basePtr,
672 stridedMetadata->offset, stridedMetadata->sizes,
673 stridedMetadata->strides);
691 struct ExtractStridedMetadataOpCollapseShapeFolder
695 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
697 auto collapseShapeOp =
698 op.getSource().getDefiningOp<memref::CollapseShapeOp>();
699 if (!collapseShapeOp)
702 FailureOr<StridedMetadata> stridedMetadata =
703 resolveReshapeStridedMetadata<memref::CollapseShapeOp>(
704 rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride);
705 if (failed(stridedMetadata)) {
708 "failed to resolve metadata in terms of source collapse_shape op");
711 Location loc = collapseShapeOp.getLoc();
713 results.push_back(stridedMetadata->basePtr);
715 stridedMetadata->offset));
719 stridedMetadata->strides));
728 struct ExtractStridedMetadataOpExpandShapeFolder
732 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
734 auto expandShapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
738 FailureOr<StridedMetadata> stridedMetadata =
739 resolveReshapeStridedMetadata<memref::ExpandShapeOp>(
740 rewriter, expandShapeOp, getExpandedSizes, getExpandedStrides);
741 if (failed(stridedMetadata)) {
743 op,
"failed to resolve metadata in terms of source expand_shape op");
746 Location loc = expandShapeOp.getLoc();
748 results.push_back(stridedMetadata->basePtr);
750 stridedMetadata->offset));
754 stridedMetadata->strides));
774 template <
typename AllocLikeOp>
775 struct ExtractStridedMetadataOpAllocFolder
780 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
782 auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>();
786 auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
787 if (!memRefType.getLayout().isIdentity())
789 allocLikeOp,
"alloc-like operations should have been normalized");
792 int rank = memRefType.getRank();
795 ValueRange dynamic = allocLikeOp.getDynamicSizes();
798 unsigned dynamicPos = 0;
799 for (int64_t size : memRefType.getShape()) {
800 if (ShapedType::isDynamic(size))
801 sizes.push_back(dynamic[dynamicPos++]);
809 unsigned symbolNumber = 0;
810 for (
int i = rank - 2; i >= 0; --i) {
812 assert(i + 1 + symbolNumber == sizes.size() &&
813 "The ArrayRef should encompass the last #symbolNumber sizes");
816 sizesInvolvedInStride);
821 results.reserve(rank * 2 + 2);
823 auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
825 if (op.getBaseBuffer().use_empty()) {
826 results.push_back(
nullptr);
828 if (allocLikeOp.getType() == baseBufferType)
829 results.push_back(allocLikeOp);
831 results.push_back(rewriter.
create<memref::ReinterpretCastOp>(
832 loc, baseBufferType, allocLikeOp, offset,
838 results.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, offset));
865 struct ExtractStridedMetadataOpGetGlobalFolder
870 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
872 auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>();
876 auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
877 if (!memRefType.getLayout().isIdentity()) {
880 "get-global operation result should have been normalized");
884 int rank = memRefType.getRank();
888 assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
889 "unexpected dynamic shape for result of `memref.get_global` op");
896 results.reserve(rank * 2 + 2);
898 auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
900 if (getGlobalOp.getType() == baseBufferType)
901 results.push_back(getGlobalOp);
903 results.push_back(rewriter.
create<memref::ReinterpretCastOp>(
904 loc, baseBufferType, getGlobalOp, offset,
909 results.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, offset));
911 for (
auto size : sizes)
912 results.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, size));
914 for (
auto stride : strides)
915 results.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, stride));
933 struct ExtractStridedMetadataOpAssumeAlignmentFolder
938 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
940 auto assumeAlignmentOp =
941 op.getSource().getDefiningOp<memref::AssumeAlignmentOp>();
942 if (!assumeAlignmentOp)
946 op, assumeAlignmentOp.getViewSource());
953 class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
958 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
961 extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
965 extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
984 class ExtractStridedMetadataOpReinterpretCastFolder
989 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
991 auto reinterpretCastOp = extractStridedMetadataOp.getSource()
992 .getDefiningOp<memref::ReinterpretCastOp>();
993 if (!reinterpretCastOp)
996 Location loc = extractStridedMetadataOp.getLoc();
999 if (failed(extractStridedMetadataOp.inferReturnTypes(
1000 rewriter.
getContext(), loc, {reinterpretCastOp.getSource()},
1002 inferredReturnTypes)))
1004 reinterpretCastOp,
"reinterpret_cast source's type is incompatible");
1006 auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
1007 unsigned rank = memrefType.getRank();
1009 results.resize_for_overwrite(rank * 2 + 2);
1011 auto newExtractStridedMetadata =
1012 rewriter.
create<memref::ExtractStridedMetadataOp>(
1013 loc, reinterpretCastOp.getSource());
1016 results[0] = newExtractStridedMetadata.getBaseBuffer();
1020 rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]);
1022 const unsigned sizeStartIdx = 2;
1023 const unsigned strideStartIdx = sizeStartIdx + rank;
1027 for (
unsigned i = 0; i < rank; ++i) {
1028 results[sizeStartIdx + i] = sizes[i];
1029 results[strideStartIdx + i] = strides[i];
1031 rewriter.
replaceOp(extractStridedMetadataOp,
1058 class ExtractStridedMetadataOpCastFolder
1063 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1065 Value source = extractStridedMetadataOp.getSource();
1070 Location loc = extractStridedMetadataOp.getLoc();
1073 if (failed(extractStridedMetadataOp.inferReturnTypes(
1074 rewriter.
getContext(), loc, {castOp.getSource()},
1076 inferredReturnTypes)))
1078 "cast source's type is incompatible");
1080 auto memrefType = cast<MemRefType>(source.
getType());
1081 unsigned rank = memrefType.getRank();
1083 results.resize_for_overwrite(rank * 2 + 2);
1085 auto newExtractStridedMetadata =
1086 rewriter.
create<memref::ExtractStridedMetadataOp>(loc,
1087 castOp.getSource());
1090 results[0] = newExtractStridedMetadata.getBaseBuffer();
1092 auto getConstantOrValue = [&rewriter](int64_t constant,
1094 return !ShapedType::isDynamic(constant)
1099 auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
1100 assert(sourceStrides.size() == rank &&
"unexpected number of strides");
1104 getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
1106 const unsigned sizeStartIdx = 2;
1107 const unsigned strideStartIdx = sizeStartIdx + rank;
1112 for (
unsigned i = 0; i < rank; ++i) {
1113 results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
1114 results[strideStartIdx + i] =
1115 getConstantOrValue(sourceStrides[i], strides[i]);
1117 rewriter.
replaceOp(extractStridedMetadataOp,
1133 class ExtractStridedMetadataOpMemorySpaceCastFolder
1138 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1140 Location loc = extractStridedMetadataOp.getLoc();
1141 Value source = extractStridedMetadataOp.getSource();
1142 auto memSpaceCastOp = source.
getDefiningOp<memref::MemorySpaceCastOp>();
1143 if (!memSpaceCastOp)
1145 auto newExtractStridedMetadata =
1146 rewriter.
create<memref::ExtractStridedMetadataOp>(
1147 loc, memSpaceCastOp.getSource());
1155 if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) {
1156 auto baseBuffer = results[0];
1157 auto baseBufferType = cast<MemRefType>(baseBuffer.getType());
1159 newTypeBuilder.setMemorySpace(
1160 memSpaceCastOp.getResult().getType().getMemorySpace());
1161 results[0] = rewriter.
create<memref::MemorySpaceCastOp>(
1162 loc,
Type{newTypeBuilder}, baseBuffer);
1164 results[0] =
nullptr;
1166 rewriter.
replaceOp(extractStridedMetadataOp, results);
1178 class ExtractStridedMetadataOpExtractStridedMetadataFolder
1183 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1185 auto sourceExtractStridedMetadataOp =
1186 extractStridedMetadataOp.getSource()
1187 .getDefiningOp<memref::ExtractStridedMetadataOp>();
1188 if (!sourceExtractStridedMetadataOp)
1190 Location loc = extractStridedMetadataOp.getLoc();
1191 rewriter.
replaceOp(extractStridedMetadataOp,
1192 {sourceExtractStridedMetadataOp.getBaseBuffer(),
1203 ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes,
1204 getExpandedStrides>,
1205 ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize,
1206 getCollapsedStride>,
1207 ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
1208 ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1209 ExtractStridedMetadataOpCollapseShapeFolder,
1210 ExtractStridedMetadataOpExpandShapeFolder,
1211 ExtractStridedMetadataOpGetGlobalFolder,
1212 RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
1213 ExtractStridedMetadataOpReinterpretCastFolder,
1214 ExtractStridedMetadataOpSubviewFolder,
1215 ExtractStridedMetadataOpCastFolder,
1216 ExtractStridedMetadataOpMemorySpaceCastFolder,
1217 ExtractStridedMetadataOpAssumeAlignmentFolder,
1218 ExtractStridedMetadataOpExtractStridedMetadataFolder>(
1224 patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
1225 ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1226 ExtractStridedMetadataOpCollapseShapeFolder,
1227 ExtractStridedMetadataOpExpandShapeFolder,
1228 ExtractStridedMetadataOpGetGlobalFolder,
1229 ExtractStridedMetadataOpSubviewFolder,
1230 RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
1231 ExtractStridedMetadataOpReinterpretCastFolder,
1232 ExtractStridedMetadataOpCastFolder,
1233 ExtractStridedMetadataOpMemorySpaceCastFolder,
1234 ExtractStridedMetadataOpAssumeAlignmentFolder,
1235 ExtractStridedMetadataOpExtractStridedMetadataFolder>(
1245 struct ExpandStridedMetadataPass final
1246 :
public memref::impl::ExpandStridedMetadataPassBase<
1247 ExpandStridedMetadataPass> {
1248 void runOnOperation()
override;
1253 void ExpandStridedMetadataPass::runOnOperation() {
static MLIRContext * getContext(OpFoldResult val)
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
AffineExpr getAffineConstantExpr(int64_t constant)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This is a builder type that keeps local references to arguments.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for resolving memref.extract_strided_metadata into memref.extract_strided_metadata o...
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for expanding memref operations that modify the metadata (sizes, offset,...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
SmallVector< int64_t > computeSuffixProduct(ArrayRef< int64_t > sizes)
Given a set of sizes, return the suffix product.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.