25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallBitVector.h"
31 #define GEN_PASS_DEF_EXPANDSTRIDEDMETADATA
32 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
41 struct StridedMetadata {
59 static FailureOr<StridedMetadata>
61 memref::SubViewOp subview) {
64 Value source = subview.getSource();
65 auto sourceType = cast<MemRefType>(source.
getType());
66 unsigned sourceRank = sourceType.getRank();
68 auto newExtractStridedMetadata =
69 rewriter.
create<memref::ExtractStridedMetadataOp>(origLoc, source);
81 auto origStrides = newExtractStridedMetadata.getStrides();
89 values[0] = ShapedType::isDynamic(sourceOffset)
91 : rewriter.getIndexAttr(sourceOffset);
96 for (
unsigned i = 0; i < sourceRank; ++i) {
99 ShapedType::isDynamic(sourceStrides[i])
103 rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));
106 unsigned baseIdxForDim = 1 + 2 * i;
107 unsigned subOffsetForDim = baseIdxForDim;
108 unsigned origStrideForDim = baseIdxForDim + 1;
109 expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
110 values[subOffsetForDim] = subOffsets[i];
111 values[origStrideForDim] = origStride;
121 if (computedOffset && !ShapedType::isDynamic(resultOffset))
122 assert(*computedOffset == resultOffset &&
123 "mismatch between computed offset and result type offset");
129 auto subType = cast<MemRefType>(subview.getType());
130 unsigned subRank = subType.getRank();
139 llvm::SmallBitVector droppedDims = subview.getDroppedDims();
142 finalSizes.reserve(subRank);
145 finalStrides.reserve(subRank);
151 for (
unsigned i = 0; i < sourceRank; ++i) {
152 if (droppedDims.test(i))
155 finalSizes.push_back(subSizes[i]);
156 finalStrides.push_back(strides[i]);
161 if (computedStride && !ShapedType::isDynamic(resultStrides[
j]))
162 assert(*computedStride == resultStrides[
j] &&
163 "mismatch between computed stride and result type stride");
167 assert(finalSizes.size() == subRank &&
168 "Should have populated all the values at this point");
169 return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset,
170 finalSizes, finalStrides};
191 LogicalResult matchAndRewrite(memref::SubViewOp subview,
193 FailureOr<StridedMetadata> stridedMetadata =
194 resolveSubviewStridedMetadata(rewriter, subview);
195 if (failed(stridedMetadata)) {
197 "failed to resolve subview metadata");
201 subview, subview.getType(), stridedMetadata->basePtr,
202 stridedMetadata->offset, stridedMetadata->sizes,
203 stridedMetadata->strides);
221 struct ExtractStridedMetadataOpSubviewFolder
225 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
227 auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>();
231 FailureOr<StridedMetadata> stridedMetadata =
232 resolveSubviewStridedMetadata(rewriter, subviewOp);
233 if (failed(stridedMetadata)) {
235 op,
"failed to resolve metadata in terms of source subview op");
239 results.reserve(subviewOp.getType().getRank() * 2 + 2);
240 results.push_back(stridedMetadata->basePtr);
242 stridedMetadata->offset));
246 stridedMetadata->strides));
269 getExpandedSizes(memref::ExpandShapeOp expandShape,
OpBuilder &builder,
272 expandShape.getReassociationIndices()[groupId];
273 assert(!reassocGroup.empty() &&
274 "Reassociation group should have at least one dimension");
276 unsigned groupSize = reassocGroup.size();
279 uint64_t productOfAllStaticSizes = 1;
280 std::optional<unsigned> dynSizeIdx;
281 MemRefType expandShapeType = expandShape.getResultType();
284 for (
unsigned i = 0; i < groupSize; ++i) {
285 uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
286 if (ShapedType::isDynamic(dimSize)) {
287 assert(!dynSizeIdx &&
"There must be at most one dynamic size per group");
291 productOfAllStaticSizes *= dimSize;
301 builder, expandShape.getLoc(), s0.
floorDiv(productOfAllStaticSizes),
305 return expandedSizes;
338 expandShape.getReassociationIndices()[groupId];
339 assert(!reassocGroup.empty() &&
340 "Reassociation group should have at least one dimension");
342 unsigned groupSize = reassocGroup.size();
343 MemRefType expandShapeType = expandShape.getResultType();
345 std::optional<int64_t> dynSizeIdx;
349 uint64_t currentStride = 1;
351 for (
int i = groupSize - 1; i >= 0; --i) {
352 expandedStrides[i] = builder.
getIndexAttr(currentStride);
353 uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
354 if (ShapedType::isDynamic(dimSize)) {
355 assert(!dynSizeIdx &&
"There must be at most one dynamic size per group");
360 currentStride *= dimSize;
364 Value source = expandShape.getSrc();
365 auto sourceType = cast<MemRefType>(source.
getType());
368 OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
369 ? origStrides[groupId]
373 int64_t doneStrideIdx = 0;
377 int64_t productOfAllStaticSizes = currentStride;
378 assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
379 "We shouldn't be able to change dynamicity");
384 for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
385 int64_t baseExpandedStride =
386 cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
389 builder, expandShape.getLoc(),
390 (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
391 {origSize, origStride});
397 for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
398 int64_t baseExpandedStride =
399 cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
402 builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
405 return expandedStrides;
422 unsigned numberOfSymbols = 0;
423 unsigned groupSize = indices.size();
424 for (
unsigned i = 0; i < groupSize; ++i) {
427 unsigned srcIdx = indices[i];
428 int64_t maybeConstant = maybeConstants[srcIdx];
430 inputValues.push_back(isDynamic(maybeConstant)
454 getCollapsedSize(memref::CollapseShapeOp collapseShape,
OpBuilder &builder,
458 MemRefType collapseShapeType = collapseShape.getResultType();
460 uint64_t size = collapseShapeType.getDimSize(groupId);
461 if (!ShapedType::isDynamic(size)) {
463 return collapsedSize;
469 Value source = collapseShape.getSrc();
470 auto sourceType = cast<MemRefType>(source.
getType());
473 collapseShape.getReassociationIndices()[groupId];
475 collapsedSize.push_back(getProductOfValues(
476 reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
477 origSizes, ShapedType::isDynamic));
479 return collapsedSize;
495 getCollapsedStride(memref::CollapseShapeOp collapseShape,
OpBuilder &builder,
499 collapseShape.getReassociationIndices()[groupId];
500 assert(!reassocGroup.empty() &&
501 "Reassociation group should have at least one dimension");
503 Value source = collapseShape.getSrc();
504 auto sourceType = cast<MemRefType>(source.
getType());
510 for (int64_t currentDim : reassocGroup) {
516 if (srcShape[currentDim] == 1)
519 int64_t currentStride = strides[currentDim];
520 groupStrides.push_back(ShapedType::isDynamic(currentStride)
521 ? origStrides[currentDim]
524 if (groupStrides.empty()) {
527 MemRefType collapsedType = collapseShape.getResultType();
528 auto [collapsedStrides, collapsedOffset] =
530 int64_t finalStride = collapsedStrides[groupId];
531 if (ShapedType::isDynamic(finalStride)) {
534 for (int64_t currentDim : reassocGroup) {
535 assert(srcShape[currentDim] == 1 &&
536 "We should be dealing with 1x1x...x1");
538 if (ShapedType::isDynamic(strides[currentDim]))
539 return {origStrides[currentDim]};
541 llvm_unreachable(
"We should have found a dynamic stride");
564 template <
typename ReassociativeReshapeLikeOp>
565 static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
566 RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape,
575 getReshapedStrides) {
578 Location origLoc = reshape.getLoc();
579 Value source = reshape.getSrc();
580 auto sourceType = cast<MemRefType>(source.
getType());
581 unsigned sourceRank = sourceType.getRank();
583 auto newExtractStridedMetadata =
584 rewriter.
create<memref::ExtractStridedMetadataOp>(origLoc, source);
588 MemRefType reshapeType = reshape.getResultType();
589 unsigned reshapeRank = reshapeType.getRank();
592 ShapedType::isDynamic(offset)
594 : rewriter.getIndexAttr(offset);
597 if (sourceRank == 0) {
599 return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
604 finalSizes.reserve(reshapeRank);
606 finalStrides.reserve(reshapeRank);
613 unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
614 for (; idx != endIdx; ++idx) {
616 getReshapedSizes(reshape, rewriter, origSizes, idx);
618 reshape, rewriter, origSizes, origStrides, idx);
620 unsigned groupSize = reshapedSizes.size();
621 for (
unsigned i = 0; i < groupSize; ++i) {
622 finalSizes.push_back(reshapedSizes[i]);
623 finalStrides.push_back(reshapedStrides[i]);
626 assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
627 (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
628 "We should have visited all the input dimensions");
629 assert(finalSizes.size() == reshapeRank &&
630 "We should have populated all the values");
632 return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
633 finalSizes, finalStrides};
652 template <
typename ReassociativeReshapeLikeOp,
664 LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
666 FailureOr<StridedMetadata> stridedMetadata =
667 resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp>(
668 rewriter, reshape, getReshapedSizes, getReshapedStrides);
669 if (failed(stridedMetadata)) {
671 "failed to resolve reshape metadata");
675 reshape, reshape.getType(), stridedMetadata->basePtr,
676 stridedMetadata->offset, stridedMetadata->sizes,
677 stridedMetadata->strides);
695 struct ExtractStridedMetadataOpCollapseShapeFolder
699 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
701 auto collapseShapeOp =
702 op.getSource().getDefiningOp<memref::CollapseShapeOp>();
703 if (!collapseShapeOp)
706 FailureOr<StridedMetadata> stridedMetadata =
707 resolveReshapeStridedMetadata<memref::CollapseShapeOp>(
708 rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride);
709 if (failed(stridedMetadata)) {
712 "failed to resolve metadata in terms of source collapse_shape op");
715 Location loc = collapseShapeOp.getLoc();
717 results.push_back(stridedMetadata->basePtr);
719 stridedMetadata->offset));
723 stridedMetadata->strides));
732 struct ExtractStridedMetadataOpExpandShapeFolder
736 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
738 auto expandShapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
742 FailureOr<StridedMetadata> stridedMetadata =
743 resolveReshapeStridedMetadata<memref::ExpandShapeOp>(
744 rewriter, expandShapeOp, getExpandedSizes, getExpandedStrides);
745 if (failed(stridedMetadata)) {
747 op,
"failed to resolve metadata in terms of source expand_shape op");
750 Location loc = expandShapeOp.getLoc();
752 results.push_back(stridedMetadata->basePtr);
754 stridedMetadata->offset));
758 stridedMetadata->strides));
778 template <
typename AllocLikeOp>
779 struct ExtractStridedMetadataOpAllocFolder
784 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
786 auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>();
790 auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
791 if (!memRefType.getLayout().isIdentity())
793 allocLikeOp,
"alloc-like operations should have been normalized");
796 int rank = memRefType.getRank();
799 ValueRange dynamic = allocLikeOp.getDynamicSizes();
802 unsigned dynamicPos = 0;
803 for (int64_t size : memRefType.getShape()) {
804 if (ShapedType::isDynamic(size))
805 sizes.push_back(dynamic[dynamicPos++]);
813 unsigned symbolNumber = 0;
814 for (
int i = rank - 2; i >= 0; --i) {
816 assert(i + 1 + symbolNumber == sizes.size() &&
817 "The ArrayRef should encompass the last #symbolNumber sizes");
820 sizesInvolvedInStride);
825 results.reserve(rank * 2 + 2);
827 auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
830 results.push_back(
nullptr);
832 if (allocLikeOp.getType() == baseBufferType)
833 results.push_back(allocLikeOp);
835 results.push_back(rewriter.
create<memref::ReinterpretCastOp>(
836 loc, baseBufferType, allocLikeOp, offset,
842 results.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, offset));
869 struct ExtractStridedMetadataOpGetGlobalFolder
874 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
876 auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>();
880 auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
881 if (!memRefType.getLayout().isIdentity()) {
884 "get-global operation result should have been normalized");
888 int rank = memRefType.getRank();
892 assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
893 "unexpected dynamic shape for result of `memref.get_global` op");
900 results.reserve(rank * 2 + 2);
902 auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
904 if (getGlobalOp.getType() == baseBufferType)
905 results.push_back(getGlobalOp);
907 results.push_back(rewriter.
create<memref::ReinterpretCastOp>(
908 loc, baseBufferType, getGlobalOp, offset,
913 results.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, offset));
915 for (
auto size : sizes)
916 results.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, size));
918 for (
auto stride : strides)
919 results.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, stride));
928 class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
933 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
936 extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
940 extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
959 class ExtractStridedMetadataOpReinterpretCastFolder
964 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
966 auto reinterpretCastOp = extractStridedMetadataOp.getSource()
967 .getDefiningOp<memref::ReinterpretCastOp>();
968 if (!reinterpretCastOp)
971 Location loc = extractStridedMetadataOp.getLoc();
974 if (failed(extractStridedMetadataOp.inferReturnTypes(
975 rewriter.
getContext(), loc, {reinterpretCastOp.getSource()},
977 inferredReturnTypes)))
979 reinterpretCastOp,
"reinterpret_cast source's type is incompatible");
981 auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
982 unsigned rank = memrefType.getRank();
984 results.resize_for_overwrite(rank * 2 + 2);
986 auto newExtractStridedMetadata =
987 rewriter.
create<memref::ExtractStridedMetadataOp>(
988 loc, reinterpretCastOp.getSource());
991 results[0] = newExtractStridedMetadata.getBaseBuffer();
995 rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]);
997 const unsigned sizeStartIdx = 2;
998 const unsigned strideStartIdx = sizeStartIdx + rank;
1002 for (
unsigned i = 0; i < rank; ++i) {
1003 results[sizeStartIdx + i] = sizes[i];
1004 results[strideStartIdx + i] = strides[i];
1006 rewriter.
replaceOp(extractStridedMetadataOp,
1033 class ExtractStridedMetadataOpCastFolder
1038 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1040 Value source = extractStridedMetadataOp.getSource();
1045 Location loc = extractStridedMetadataOp.getLoc();
1048 if (failed(extractStridedMetadataOp.inferReturnTypes(
1049 rewriter.
getContext(), loc, {castOp.getSource()},
1051 inferredReturnTypes)))
1053 "cast source's type is incompatible");
1055 auto memrefType = cast<MemRefType>(source.
getType());
1056 unsigned rank = memrefType.getRank();
1058 results.resize_for_overwrite(rank * 2 + 2);
1060 auto newExtractStridedMetadata =
1061 rewriter.
create<memref::ExtractStridedMetadataOp>(loc,
1062 castOp.getSource());
1065 results[0] = newExtractStridedMetadata.getBaseBuffer();
1067 auto getConstantOrValue = [&rewriter](int64_t constant,
1069 return !ShapedType::isDynamic(constant)
1075 assert(sourceStrides.size() == rank &&
"unexpected number of strides");
1079 getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
1081 const unsigned sizeStartIdx = 2;
1082 const unsigned strideStartIdx = sizeStartIdx + rank;
1087 for (
unsigned i = 0; i < rank; ++i) {
1088 results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
1089 results[strideStartIdx + i] =
1090 getConstantOrValue(sourceStrides[i], strides[i]);
1092 rewriter.
replaceOp(extractStridedMetadataOp,
1108 class ExtractStridedMetadataOpMemorySpaceCastFolder
1113 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1115 Location loc = extractStridedMetadataOp.getLoc();
1116 Value source = extractStridedMetadataOp.getSource();
1117 auto memSpaceCastOp = source.
getDefiningOp<memref::MemorySpaceCastOp>();
1118 if (!memSpaceCastOp)
1120 auto newExtractStridedMetadata =
1121 rewriter.
create<memref::ExtractStridedMetadataOp>(
1122 loc, memSpaceCastOp.getSource());
1130 if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) {
1131 auto baseBuffer = results[0];
1132 auto baseBufferType = cast<MemRefType>(baseBuffer.getType());
1134 newTypeBuilder.setMemorySpace(
1135 memSpaceCastOp.getResult().getType().getMemorySpace());
1136 results[0] = rewriter.
create<memref::MemorySpaceCastOp>(
1137 loc,
Type{newTypeBuilder}, baseBuffer);
1139 results[0] =
nullptr;
1141 rewriter.
replaceOp(extractStridedMetadataOp, results);
1153 class ExtractStridedMetadataOpExtractStridedMetadataFolder
1158 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1160 auto sourceExtractStridedMetadataOp =
1161 extractStridedMetadataOp.getSource()
1162 .getDefiningOp<memref::ExtractStridedMetadataOp>();
1163 if (!sourceExtractStridedMetadataOp)
1165 Location loc = extractStridedMetadataOp.getLoc();
1166 rewriter.
replaceOp(extractStridedMetadataOp,
1167 {sourceExtractStridedMetadataOp.getBaseBuffer(),
1177 patterns.
add<SubviewFolder,
1178 ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes,
1179 getExpandedStrides>,
1180 ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize,
1181 getCollapsedStride>,
1182 ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
1183 ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1184 ExtractStridedMetadataOpCollapseShapeFolder,
1185 ExtractStridedMetadataOpExpandShapeFolder,
1186 ExtractStridedMetadataOpGetGlobalFolder,
1187 RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
1188 ExtractStridedMetadataOpReinterpretCastFolder,
1189 ExtractStridedMetadataOpSubviewFolder,
1190 ExtractStridedMetadataOpCastFolder,
1191 ExtractStridedMetadataOpMemorySpaceCastFolder,
1192 ExtractStridedMetadataOpExtractStridedMetadataFolder>(
1198 patterns.
add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
1199 ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1200 ExtractStridedMetadataOpCollapseShapeFolder,
1201 ExtractStridedMetadataOpExpandShapeFolder,
1202 ExtractStridedMetadataOpGetGlobalFolder,
1203 ExtractStridedMetadataOpSubviewFolder,
1204 RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
1205 ExtractStridedMetadataOpReinterpretCastFolder,
1206 ExtractStridedMetadataOpCastFolder,
1207 ExtractStridedMetadataOpMemorySpaceCastFolder,
1208 ExtractStridedMetadataOpExtractStridedMetadataFolder>(
1218 struct ExpandStridedMetadataPass final
1219 :
public memref::impl::ExpandStridedMetadataBase<
1220 ExpandStridedMetadataPass> {
1221 void runOnOperation()
override;
1226 void ExpandStridedMetadataPass::runOnOperation() {
1233 return std::make_unique<ExpandStridedMetadataPass>();
static MLIRContext * getContext(OpFoldResult val)
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
AffineExpr getAffineConstantExpr(int64_t constant)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This is a builder type that keeps local references to arguments.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
bool use_empty()
Returns true if this operation has no uses.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
std::unique_ptr< Pass > createExpandStridedMetadataPass()
Creates an operation pass to expand some memref operation into easier to reason about operations.
void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for resolving memref.extract_strided_metadata into memref.extract_strided_metadata o...
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for expanding memref operations that modify the metadata (sizes, offset,...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
SmallVector< int64_t > computeSuffixProduct(ArrayRef< int64_t > sizes)
Given a set of sizes, return the suffix product.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.