#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"

#define GEN_PASS_DEF_EXPANDSTRIDEDMETADATAPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
struct StridedMetadata {
  Value basePtr;
  OpFoldResult offset;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
};
static FailureOr<StridedMetadata>
resolveSubviewStridedMetadata(RewriterBase &rewriter,
                              memref::SubViewOp subview) {
  Location origLoc = subview.getLoc();
  Value source = subview.getSource();
  auto sourceType = cast<MemRefType>(source.getType());
  unsigned sourceRank = sourceType.getRank();

  auto newExtractStridedMetadata =
      memref::ExtractStridedMetadataOp::create(rewriter, origLoc, source);

  auto [sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();

  auto [resultStrides, resultOffset] = subview.getType().getStridesAndOffset();

  auto origStrides = newExtractStridedMetadata.getStrides();

  values[0] = ShapedType::isDynamic(sourceOffset)
                  ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
                  : rewriter.getIndexAttr(sourceOffset);

  for (unsigned i = 0; i < sourceRank; ++i) {
    OpFoldResult origStride =
        ShapedType::isDynamic(sourceStrides[i])
            ? origStrides[i]
            : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i]));
    strides.push_back(makeComposedFoldedAffineApply(
        rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));

    unsigned baseIdxForDim = 1 + 2 * i;
    unsigned subOffsetForDim = baseIdxForDim;
    unsigned origStrideForDim = baseIdxForDim + 1;
    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
    values[subOffsetForDim] = subOffsets[i];
    values[origStrideForDim] = origStride;
  }

  if (computedOffset && ShapedType::isStatic(resultOffset))
    assert(*computedOffset == resultOffset &&
           "mismatch between computed offset and result type offset");

  auto subType = cast<MemRefType>(subview.getType());
  unsigned subRank = subType.getRank();

  llvm::SmallBitVector droppedDims = subview.getDroppedDims();

  finalSizes.reserve(subRank);

  finalStrides.reserve(subRank);

  for (unsigned i = 0; i < sourceRank; ++i) {
    if (droppedDims.test(i))
      continue;

    finalSizes.push_back(subSizes[i]);
    finalStrides.push_back(strides[i]);

    if (computedStride && ShapedType::isStatic(resultStrides[j]))
      assert(*computedStride == resultStrides[j] &&
             "mismatch between computed stride and result type stride");
  }
  assert(finalSizes.size() == subRank &&
         "Should have populated all the values at this point");
  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset,
                         finalSizes, finalStrides};
}
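// Summary of the computation above: for a subview with offsets subOffsets[i],
// sizes subSizes[i], and strides subStrides[i] taken from a source whose
// metadata is (sourceOffset, origStrides[i]), the resolved metadata is
// (with rank-reduced dimensions dropped):
//   offset     = sourceOffset + sum_i(subOffsets[i] * origStrides[i])
//   strides[i] = subStrides[i] * origStrides[i]
//   sizes[i]   = subSizes[i]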
  LogicalResult matchAndRewrite(memref::SubViewOp subview,
                                PatternRewriter &rewriter) const override {
    FailureOr<StridedMetadata> stridedMetadata =
        resolveSubviewStridedMetadata(rewriter, subview);
    if (failed(stridedMetadata)) {
        "failed to resolve subview metadata");

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        subview, subview.getType(), stridedMetadata->basePtr,
        stridedMetadata->offset, stridedMetadata->sizes,
        stridedMetadata->strides);
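// Illustrative example (hand-written sketch, not taken from the upstream
// tests): expanding
//   %v = memref.subview %src[%o, 4] [8, 16] [1, 1]
//       : memref<64x64xf32> to memref<8x16xf32, strided<[64, 1], offset: ?>>
// produces, roughly:
//   %base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %src
//   %newOff = affine.apply affine_map<()[s0] -> (s0 * 64 + 4)>()[%o]
//   %v = memref.reinterpret_cast %base to offset: [%newOff],
//        sizes: [8, 16], strides: [64, 1]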
struct ExtractStridedMetadataOpSubviewFolder

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>();

    FailureOr<StridedMetadata> stridedMetadata =
        resolveSubviewStridedMetadata(rewriter, subviewOp);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(
          op, "failed to resolve metadata in terms of source subview op");
    }

    results.reserve(subviewOp.getType().getRank() * 2 + 2);
    results.push_back(stridedMetadata->basePtr);
        stridedMetadata->offset));
        stridedMetadata->strides));
getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,

      expandShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  unsigned groupSize = reassocGroup.size();

  uint64_t productOfAllStaticSizes = 1;
  std::optional<unsigned> dynSizeIdx;
  MemRefType expandShapeType = expandShape.getResultType();

  for (unsigned i = 0; i < groupSize; ++i) {
    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
    if (ShapedType::isDynamic(dimSize)) {
      assert(!dynSizeIdx && "There must be at most one dynamic size per group");

    productOfAllStaticSizes *= dimSize;

        builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),

  return expandedSizes;
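// Example (illustrative): expanding a dynamic dimension into sizes [?, 7]
// makes the dynamic expanded size equal to the original group size
// floor-divided by the product of the static sizes, i.e. origSize floordiv 7,
// computed with the s0.floorDiv(productOfAllStaticSizes) expression above.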
      expandShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  unsigned groupSize = reassocGroup.size();
  MemRefType expandShapeType = expandShape.getResultType();

  std::optional<int64_t> dynSizeIdx;

  uint64_t currentStride = 1;

  for (int i = groupSize - 1; i >= 0; --i) {
    expandedStrides[i] = builder.getIndexAttr(currentStride);
    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
    if (ShapedType::isDynamic(dimSize)) {
      assert(!dynSizeIdx && "There must be at most one dynamic size per group");

    currentStride *= dimSize;

  Value source = expandShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());
  auto [strides, offset] = sourceType.getStridesAndOffset();

  OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
                                ? origStrides[groupId]
                                : builder.getIndexAttr(strides[groupId]);

  int64_t doneStrideIdx = 0;

  int64_t productOfAllStaticSizes = currentStride;
  assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
         "We shouldn't be able to change dynamicity");

  for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
    int64_t baseExpandedStride =
        cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
        builder, expandShape.getLoc(),
        (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
        {origSize, origStride});

  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
    int64_t baseExpandedStride =
        cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
        builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});

  return expandedStrides;
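// Example (illustrative): expanding a group of original size %n and original
// stride %s into sizes [4, ?, 8] yields, innermost first:
//   stride[2] = 1 * %s
//   stride[1] = 8 * %s
//   stride[0] = ((%n * 8) floordiv 32) * %s   // i.e. dynamicSize * 8 * %s
// Only the strides of dimensions outer to the dynamic one need the floordiv
// correction, since a stride only depends on the sizes of the inner dims.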
  unsigned numberOfSymbols = 0;
  unsigned groupSize = indices.size();
  for (unsigned i = 0; i < groupSize; ++i) {
    unsigned srcIdx = indices[i];
    int64_t maybeConstant = maybeConstants[srcIdx];

    inputValues.push_back(isDynamic(maybeConstant)
getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,

  MemRefType collapseShapeType = collapseShape.getResultType();

  uint64_t size = collapseShapeType.getDimSize(groupId);
  if (ShapedType::isStatic(size)) {
    return collapsedSize;

  Value source = collapseShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());

      collapseShape.getReassociationIndices()[groupId];

  collapsedSize.push_back(getProductOfValues(
      reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
      origSizes, ShapedType::isDynamic));

  return collapsedSize;
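// Example (illustrative): collapsing source dimensions of sizes [3, ?, 5] into
// one dimension yields a collapsed size of 3 * %dynSize * 5; when every size
// in the group is static, the constant already recorded in the result type is
// used directly.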
getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,

      collapseShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  Value source = collapseShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());

  auto [strides, offset] = sourceType.getStridesAndOffset();

  for (int64_t currentDim : reassocGroup) {
    if (srcShape[currentDim] == 1)
      continue;

    int64_t currentStride = strides[currentDim];
    lastValidStride = ShapedType::isDynamic(currentStride)
                          ? origStrides[currentDim]
                          : builder.getIndexAttr(currentStride);
  }
  if (!lastValidStride) {

    MemRefType collapsedType = collapseShape.getResultType();
    auto [collapsedStrides, collapsedOffset] =
        collapsedType.getStridesAndOffset();
    int64_t finalStride = collapsedStrides[groupId];
    if (ShapedType::isDynamic(finalStride)) {

      for (int64_t currentDim : reassocGroup) {
        assert(srcShape[currentDim] == 1 &&
               "We should be dealing with 1x1x...x1");

        if (ShapedType::isDynamic(strides[currentDim]))
          return {origStrides[currentDim]};
      }
      llvm_unreachable("We should have found a dynamic stride");

  return {lastValidStride};
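// Example (illustrative): collapsing a group with sizes [2, 1, 4] and strides
// [32, ?, 8] picks the stride of the innermost dimension whose size is not 1,
// here 8; size-1 dimensions are skipped because they do not constrain the
// layout. A group made only of size-1 dimensions falls back to the stride
// recorded in the collapsed result type (or the matching dynamic origStride).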
template <typename ReassociativeReshapeLikeOp>
static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
    RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape,
    getReshapedStrides) {

  Location origLoc = reshape.getLoc();
  Value source = reshape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());
  unsigned sourceRank = sourceType.getRank();

  auto newExtractStridedMetadata =
      memref::ExtractStridedMetadataOp::create(rewriter, origLoc, source);

  auto [strides, offset] = sourceType.getStridesAndOffset();
  MemRefType reshapeType = reshape.getResultType();
  unsigned reshapeRank = reshapeType.getRank();

  OpFoldResult offsetOfr =
      ShapedType::isDynamic(offset)
          ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
          : rewriter.getIndexAttr(offset);

  if (sourceRank == 0) {
    return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,

  finalSizes.reserve(reshapeRank);

  finalStrides.reserve(reshapeRank);

  unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
  for (; idx != endIdx; ++idx) {
        getReshapedSizes(reshape, rewriter, origSizes, idx);
        reshape, rewriter, origSizes, origStrides, idx);

    unsigned groupSize = reshapedSizes.size();
    for (unsigned i = 0; i < groupSize; ++i) {
      finalSizes.push_back(reshapedSizes[i]);
      finalStrides.push_back(reshapedStrides[i]);
    }
  }
  assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
          (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
         "We should have visited all the input dimensions");
  assert(finalSizes.size() == reshapeRank &&
         "We should have populated all the values");

  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
                         finalSizes, finalStrides};
}
template <typename ReassociativeReshapeLikeOp,

  LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
                                PatternRewriter &rewriter) const override {
    FailureOr<StridedMetadata> stridedMetadata =
        resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp>(
            rewriter, reshape, getReshapedSizes, getReshapedStrides);
    if (failed(stridedMetadata)) {
        "failed to resolve reshape metadata");

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        reshape, reshape.getType(), stridedMetadata->basePtr,
        stridedMetadata->offset, stridedMetadata->sizes,
        stridedMetadata->strides);
struct ExtractStridedMetadataOpCollapseShapeFolder

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        op.getSource().getDefiningOp<memref::CollapseShapeOp>();
    if (!collapseShapeOp)

    FailureOr<StridedMetadata> stridedMetadata =
        resolveReshapeStridedMetadata<memref::CollapseShapeOp>(
            rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride);
    if (failed(stridedMetadata)) {
        "failed to resolve metadata in terms of source collapse_shape op");

    Location loc = collapseShapeOp.getLoc();
    results.push_back(stridedMetadata->basePtr);
        stridedMetadata->offset));
        stridedMetadata->strides));
struct ExtractStridedMetadataOpExpandShapeFolder

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>();

    FailureOr<StridedMetadata> stridedMetadata =
        resolveReshapeStridedMetadata<memref::ExpandShapeOp>(
            rewriter, expandShapeOp, getExpandedSizes, getExpandedStrides);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(
          op, "failed to resolve metadata in terms of source expand_shape op");
    }

    Location loc = expandShapeOp.getLoc();
    results.push_back(stridedMetadata->basePtr);
        stridedMetadata->offset));
        stridedMetadata->strides));
template <typename AllocLikeOp>
struct ExtractStridedMetadataOpAllocFolder

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>();

    auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
    if (!memRefType.getLayout().isIdentity())
      return rewriter.notifyMatchFailure(
          allocLikeOp, "alloc-like operations should have been normalized");

    int rank = memRefType.getRank();

    ValueRange dynamic = allocLikeOp.getDynamicSizes();

    unsigned dynamicPos = 0;
    for (int64_t size : memRefType.getShape()) {
      if (ShapedType::isDynamic(size))
        sizes.push_back(dynamic[dynamicPos++]);

    unsigned symbolNumber = 0;
    for (int i = rank - 2; i >= 0; --i) {
      assert(i + 1 + symbolNumber == sizes.size() &&
             "The ArrayRef should encompass the last #symbolNumber sizes");

          sizesInvolvedInStride);

    results.reserve(rank * 2 + 2);

    auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());

    if (op.getBaseBuffer().use_empty()) {
      results.push_back(nullptr);

    if (allocLikeOp.getType() == baseBufferType)
      results.push_back(allocLikeOp);

      results.push_back(memref::ReinterpretCastOp::create(
          rewriter, loc, baseBufferType, allocLikeOp, offset,
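// For identity-layout allocations the metadata is fully determined by the
// shape: the offset is 0 and the strides are the suffix product of the sizes.
// Example (illustrative): memref.alloc() : memref<4x8x2xf32> has offset 0,
// sizes [4, 8, 2], and strides [16, 2, 1].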
struct ExtractStridedMetadataOpGetGlobalFolder

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>();

    auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
    if (!memRefType.getLayout().isIdentity()) {
          "get-global operation result should have been normalized");

    int rank = memRefType.getRank();

    assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
           "unexpected dynamic shape for result of `memref.get_global` op");

    results.reserve(rank * 2 + 2);

    auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());

    if (getGlobalOp.getType() == baseBufferType)
      results.push_back(getGlobalOp);

      results.push_back(memref::ReinterpretCastOp::create(
          rewriter, loc, baseBufferType, getGlobalOp, offset,

    for (auto size : sizes)

    for (auto stride : strides)
struct ExtractStridedMetadataOpAssumeAlignmentFolder

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto assumeAlignmentOp =
        op.getSource().getDefiningOp<memref::AssumeAlignmentOp>();
    if (!assumeAlignmentOp)

        op, assumeAlignmentOp.getViewSource());
class RewriteExtractAlignedPointerAsIndexOfViewLikeOp

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  PatternRewriter &rewriter) const override {
    auto viewLikeOp =
        extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
    if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest())

    extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
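// Rationale: a view-like op (cast, subview, expand/collapse_shape, ...) does
// not change the underlying allocation, so the aligned pointer of its result
// is the aligned pointer of its source; the extract op is simply redirected.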
class ExtractStridedMetadataOpReinterpretCastFolder

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    auto reinterpretCastOp = extractStridedMetadataOp.getSource()
                                 .getDefiningOp<memref::ReinterpretCastOp>();
    if (!reinterpretCastOp)

    Location loc = extractStridedMetadataOp.getLoc();

    if (failed(extractStridedMetadataOp.inferReturnTypes(
            rewriter.getContext(), loc, {reinterpretCastOp.getSource()},
            inferredReturnTypes)))
      return rewriter.notifyMatchFailure(
          reinterpretCastOp, "reinterpret_cast source's type is incompatible");

    auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
    unsigned rank = memrefType.getRank();

    results.resize_for_overwrite(rank * 2 + 2);

    auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, reinterpretCastOp.getSource());

    results[0] = newExtractStridedMetadata.getBaseBuffer();

        rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]);

    const unsigned sizeStartIdx = 2;
    const unsigned strideStartIdx = sizeStartIdx + rank;

    for (unsigned i = 0; i < rank; ++i) {
      results[sizeStartIdx + i] = sizes[i];
      results[strideStartIdx + i] = strides[i];
    }
    rewriter.replaceOp(extractStridedMetadataOp,
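// Example (illustrative): for
//   %r = memref.reinterpret_cast %src to offset: [%o], sizes: [10, %d],
//        strides: [%s, 1]
//   %base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %r
// the offset, sizes, and strides are read directly off the reinterpret_cast
// operands (%o, (10, %d), (%s, 1)), while %base is taken from a fresh
// extract_strided_metadata of %src.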
class ExtractStridedMetadataOpCastFolder

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    Value source = extractStridedMetadataOp.getSource();

    Location loc = extractStridedMetadataOp.getLoc();

    if (failed(extractStridedMetadataOp.inferReturnTypes(
            rewriter.getContext(), loc, {castOp.getSource()},
            inferredReturnTypes)))
          "cast source's type is incompatible");

    auto memrefType = cast<MemRefType>(source.getType());
    unsigned rank = memrefType.getRank();

    results.resize_for_overwrite(rank * 2 + 2);

    auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, castOp.getSource());

    results[0] = newExtractStridedMetadata.getBaseBuffer();

    auto getConstantOrValue = [&rewriter](int64_t constant,
      return ShapedType::isStatic(constant)

    auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
    assert(sourceStrides.size() == rank && "unexpected number of strides");

        getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());

    const unsigned sizeStartIdx = 2;
    const unsigned strideStartIdx = sizeStartIdx + rank;

    for (unsigned i = 0; i < rank; ++i) {
      results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
      results[strideStartIdx + i] =
          getConstantOrValue(sourceStrides[i], strides[i]);
    }
    rewriter.replaceOp(extractStridedMetadataOp,
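// The cast folder keeps whatever the cast's result type pins down statically:
// static sizes, strides, and offset become index constants, while values that
// remain dynamic are forwarded from the extract_strided_metadata of the
// cast's source.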
class ExtractStridedMetadataOpMemorySpaceCastFolder

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    Location loc = extractStridedMetadataOp.getLoc();
    Value source = extractStridedMetadataOp.getSource();
    auto memSpaceCastOp = source.getDefiningOp<memref::MemorySpaceCastOp>();
    if (!memSpaceCastOp)

    auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, memSpaceCastOp.getSource());

    if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) {
      auto baseBuffer = results[0];
      auto baseBufferType = cast<MemRefType>(baseBuffer.getType());

      newTypeBuilder.setMemorySpace(
          memSpaceCastOp.getResult().getType().getMemorySpace());
      results[0] = memref::MemorySpaceCastOp::create(
          rewriter, loc, Type{newTypeBuilder}, baseBuffer);

      results[0] = nullptr;

    rewriter.replaceOp(extractStridedMetadataOp, results);
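// memref.memory_space_cast only changes the memory space, so the offset,
// sizes, and strides can be taken from the source as-is; only the base buffer
// (when actually used) is memory_space_cast back into the result's space.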
class ExtractStridedMetadataOpExtractStridedMetadataFolder

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    auto sourceExtractStridedMetadataOp =
        extractStridedMetadataOp.getSource()
            .getDefiningOp<memref::ExtractStridedMetadataOp>();
    if (!sourceExtractStridedMetadataOp)

    Location loc = extractStridedMetadataOp.getLoc();
    rewriter.replaceOp(extractStridedMetadataOp,
                       {sourceExtractStridedMetadataOp.getBaseBuffer(),
               ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes,
                             getExpandedStrides>,
               ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize,
                             getCollapsedStride>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
               ExtractStridedMetadataOpCollapseShapeFolder,
               ExtractStridedMetadataOpExpandShapeFolder,
               ExtractStridedMetadataOpGetGlobalFolder,
               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
               ExtractStridedMetadataOpReinterpretCastFolder,
               ExtractStridedMetadataOpSubviewFolder,
               ExtractStridedMetadataOpCastFolder,
               ExtractStridedMetadataOpMemorySpaceCastFolder,
               ExtractStridedMetadataOpAssumeAlignmentFolder,
               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
void memref::populateResolveExtractStridedMetadataPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
               ExtractStridedMetadataOpCollapseShapeFolder,
               ExtractStridedMetadataOpExpandShapeFolder,
               ExtractStridedMetadataOpGetGlobalFolder,
               ExtractStridedMetadataOpSubviewFolder,
               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
               ExtractStridedMetadataOpReinterpretCastFolder,
               ExtractStridedMetadataOpCastFolder,
               ExtractStridedMetadataOpMemorySpaceCastFolder,
               ExtractStridedMetadataOpAssumeAlignmentFolder,
               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
      patterns.getContext());
}
struct ExpandStridedMetadataPass final
    : public memref::impl::ExpandStridedMetadataPassBase<
          ExpandStridedMetadataPass> {
  void runOnOperation() override;
};

void ExpandStridedMetadataPass::runOnOperation() {
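  // Minimal sketch of the elided body, assuming the standard greedy rewrite
  // driver (the exact configuration may differ in the original source):
  //   RewritePatternSet patterns(&getContext());
  //   memref::populateExpandStridedMetadataPatterns(patterns);
  //   (void)applyPatternsGreedily(getOperation(), std::move(patterns));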