25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallBitVector.h"
31 #define GEN_PASS_DEF_EXPANDSTRIDEDMETADATA
32 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
41 struct StridedMetadata {
59 static FailureOr<StridedMetadata>
61 memref::SubViewOp subview) {
64 Value source = subview.getSource();
65 auto sourceType = cast<MemRefType>(source.
getType());
66 unsigned sourceRank = sourceType.getRank();
68 auto newExtractStridedMetadata =
69 rewriter.
create<memref::ExtractStridedMetadataOp>(origLoc, source);
81 auto origStrides = newExtractStridedMetadata.getStrides();
89 values[0] = ShapedType::isDynamic(sourceOffset)
91 : rewriter.getIndexAttr(sourceOffset);
96 for (
unsigned i = 0; i < sourceRank; ++i) {
99 ShapedType::isDynamic(sourceStrides[i])
103 rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));
106 unsigned baseIdxForDim = 1 + 2 * i;
107 unsigned subOffsetForDim = baseIdxForDim;
108 unsigned origStrideForDim = baseIdxForDim + 1;
109 expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
110 values[subOffsetForDim] = subOffsets[i];
111 values[origStrideForDim] = origStride;
121 if (computedOffset && !ShapedType::isDynamic(resultOffset))
122 assert(*computedOffset == resultOffset &&
123 "mismatch between computed offset and result type offset");
129 auto subType = cast<MemRefType>(subview.getType());
130 unsigned subRank = subType.getRank();
139 llvm::SmallBitVector droppedDims = subview.getDroppedDims();
142 finalSizes.reserve(subRank);
145 finalStrides.reserve(subRank);
151 for (
unsigned i = 0; i < sourceRank; ++i) {
152 if (droppedDims.test(i))
155 finalSizes.push_back(subSizes[i]);
156 finalStrides.push_back(strides[i]);
161 if (computedStride && !ShapedType::isDynamic(resultStrides[
j]))
162 assert(*computedStride == resultStrides[
j] &&
163 "mismatch between computed stride and result type stride");
167 assert(finalSizes.size() == subRank &&
168 "Should have populated all the values at this point");
169 return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset,
170 finalSizes, finalStrides};
191 LogicalResult matchAndRewrite(memref::SubViewOp subview,
193 FailureOr<StridedMetadata> stridedMetadata =
194 resolveSubviewStridedMetadata(rewriter, subview);
195 if (failed(stridedMetadata)) {
197 "failed to resolve subview metadata");
201 subview, subview.getType(), stridedMetadata->basePtr,
202 stridedMetadata->offset, stridedMetadata->sizes,
203 stridedMetadata->strides);
221 struct ExtractStridedMetadataOpSubviewFolder
225 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
227 auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>();
231 FailureOr<StridedMetadata> stridedMetadata =
232 resolveSubviewStridedMetadata(rewriter, subviewOp);
233 if (failed(stridedMetadata)) {
235 op,
"failed to resolve metadata in terms of source subview op");
239 results.reserve(subviewOp.getType().getRank() * 2 + 2);
240 results.push_back(stridedMetadata->basePtr);
242 stridedMetadata->offset));
246 stridedMetadata->strides));
269 getExpandedSizes(memref::ExpandShapeOp expandShape,
OpBuilder &builder,
272 expandShape.getReassociationIndices()[groupId];
273 assert(!reassocGroup.empty() &&
274 "Reassociation group should have at least one dimension");
276 unsigned groupSize = reassocGroup.size();
279 uint64_t productOfAllStaticSizes = 1;
280 std::optional<unsigned> dynSizeIdx;
281 MemRefType expandShapeType = expandShape.getResultType();
284 for (
unsigned i = 0; i < groupSize; ++i) {
285 uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
286 if (ShapedType::isDynamic(dimSize)) {
287 assert(!dynSizeIdx &&
"There must be at most one dynamic size per group");
291 productOfAllStaticSizes *= dimSize;
301 builder, expandShape.getLoc(), s0.
floorDiv(productOfAllStaticSizes),
305 return expandedSizes;
338 expandShape.getReassociationIndices()[groupId];
339 assert(!reassocGroup.empty() &&
340 "Reassociation group should have at least one dimension");
342 unsigned groupSize = reassocGroup.size();
343 MemRefType expandShapeType = expandShape.getResultType();
345 std::optional<int64_t> dynSizeIdx;
349 uint64_t currentStride = 1;
351 for (
int i = groupSize - 1; i >= 0; --i) {
352 expandedStrides[i] = builder.
getIndexAttr(currentStride);
353 uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
354 if (ShapedType::isDynamic(dimSize)) {
355 assert(!dynSizeIdx &&
"There must be at most one dynamic size per group");
360 currentStride *= dimSize;
364 Value source = expandShape.getSrc();
365 auto sourceType = cast<MemRefType>(source.
getType());
368 OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
369 ? origStrides[groupId]
373 int64_t doneStrideIdx = 0;
377 int64_t productOfAllStaticSizes = currentStride;
378 assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
379 "We shouldn't be able to change dynamicity");
384 for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
385 int64_t baseExpandedStride =
386 cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
389 builder, expandShape.getLoc(),
390 (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
391 {origSize, origStride});
397 for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
398 int64_t baseExpandedStride =
399 cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
402 builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
405 return expandedStrides;
422 unsigned numberOfSymbols = 0;
423 unsigned groupSize = indices.size();
424 for (
unsigned i = 0; i < groupSize; ++i) {
427 unsigned srcIdx = indices[i];
428 int64_t maybeConstant = maybeConstants[srcIdx];
430 inputValues.push_back(isDynamic(maybeConstant)
454 getCollapsedSize(memref::CollapseShapeOp collapseShape,
OpBuilder &builder,
458 MemRefType collapseShapeType = collapseShape.getResultType();
460 uint64_t size = collapseShapeType.getDimSize(groupId);
461 if (!ShapedType::isDynamic(size)) {
463 return collapsedSize;
469 Value source = collapseShape.getSrc();
470 auto sourceType = cast<MemRefType>(source.
getType());
473 collapseShape.getReassociationIndices()[groupId];
475 collapsedSize.push_back(getProductOfValues(
476 reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
477 origSizes, ShapedType::isDynamic));
479 return collapsedSize;
495 getCollapsedStride(memref::CollapseShapeOp collapseShape,
OpBuilder &builder,
499 collapseShape.getReassociationIndices()[groupId];
500 assert(!reassocGroup.empty() &&
501 "Reassociation group should have at least one dimension");
503 Value source = collapseShape.getSrc();
504 auto sourceType = cast<MemRefType>(source.
getType());
512 for (int64_t currentDim : reassocGroup) {
518 if (srcShape[currentDim] == 1)
521 int64_t currentStride = strides[currentDim];
522 lastValidStride = ShapedType::isDynamic(currentStride)
523 ? origStrides[currentDim]
526 if (!lastValidStride) {
529 MemRefType collapsedType = collapseShape.getResultType();
530 auto [collapsedStrides, collapsedOffset] =
532 int64_t finalStride = collapsedStrides[groupId];
533 if (ShapedType::isDynamic(finalStride)) {
536 for (int64_t currentDim : reassocGroup) {
537 assert(srcShape[currentDim] == 1 &&
538 "We should be dealing with 1x1x...x1");
540 if (ShapedType::isDynamic(strides[currentDim]))
541 return {origStrides[currentDim]};
543 llvm_unreachable(
"We should have found a dynamic stride");
548 return {lastValidStride};
561 template <
typename ReassociativeReshapeLikeOp>
562 static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
563 RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape,
572 getReshapedStrides) {
575 Location origLoc = reshape.getLoc();
576 Value source = reshape.getSrc();
577 auto sourceType = cast<MemRefType>(source.
getType());
578 unsigned sourceRank = sourceType.getRank();
580 auto newExtractStridedMetadata =
581 rewriter.
create<memref::ExtractStridedMetadataOp>(origLoc, source);
585 MemRefType reshapeType = reshape.getResultType();
586 unsigned reshapeRank = reshapeType.getRank();
589 ShapedType::isDynamic(offset)
591 : rewriter.getIndexAttr(offset);
594 if (sourceRank == 0) {
596 return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
601 finalSizes.reserve(reshapeRank);
603 finalStrides.reserve(reshapeRank);
610 unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
611 for (; idx != endIdx; ++idx) {
613 getReshapedSizes(reshape, rewriter, origSizes, idx);
615 reshape, rewriter, origSizes, origStrides, idx);
617 unsigned groupSize = reshapedSizes.size();
618 for (
unsigned i = 0; i < groupSize; ++i) {
619 finalSizes.push_back(reshapedSizes[i]);
620 finalStrides.push_back(reshapedStrides[i]);
623 assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
624 (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
625 "We should have visited all the input dimensions");
626 assert(finalSizes.size() == reshapeRank &&
627 "We should have populated all the values");
629 return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
630 finalSizes, finalStrides};
649 template <
typename ReassociativeReshapeLikeOp,
661 LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
663 FailureOr<StridedMetadata> stridedMetadata =
664 resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp>(
665 rewriter, reshape, getReshapedSizes, getReshapedStrides);
666 if (failed(stridedMetadata)) {
668 "failed to resolve reshape metadata");
672 reshape, reshape.getType(), stridedMetadata->basePtr,
673 stridedMetadata->offset, stridedMetadata->sizes,
674 stridedMetadata->strides);
692 struct ExtractStridedMetadataOpCollapseShapeFolder
696 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
698 auto collapseShapeOp =
699 op.getSource().getDefiningOp<memref::CollapseShapeOp>();
700 if (!collapseShapeOp)
703 FailureOr<StridedMetadata> stridedMetadata =
704 resolveReshapeStridedMetadata<memref::CollapseShapeOp>(
705 rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride);
706 if (failed(stridedMetadata)) {
709 "failed to resolve metadata in terms of source collapse_shape op");
712 Location loc = collapseShapeOp.getLoc();
714 results.push_back(stridedMetadata->basePtr);
716 stridedMetadata->offset));
720 stridedMetadata->strides));
729 struct ExtractStridedMetadataOpExpandShapeFolder
733 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
735 auto expandShapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
739 FailureOr<StridedMetadata> stridedMetadata =
740 resolveReshapeStridedMetadata<memref::ExpandShapeOp>(
741 rewriter, expandShapeOp, getExpandedSizes, getExpandedStrides);
742 if (failed(stridedMetadata)) {
744 op,
"failed to resolve metadata in terms of source expand_shape op");
747 Location loc = expandShapeOp.getLoc();
749 results.push_back(stridedMetadata->basePtr);
751 stridedMetadata->offset));
755 stridedMetadata->strides));
775 template <
typename AllocLikeOp>
776 struct ExtractStridedMetadataOpAllocFolder
781 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
783 auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>();
787 auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
788 if (!memRefType.getLayout().isIdentity())
790 allocLikeOp,
"alloc-like operations should have been normalized");
793 int rank = memRefType.getRank();
796 ValueRange dynamic = allocLikeOp.getDynamicSizes();
799 unsigned dynamicPos = 0;
800 for (int64_t size : memRefType.getShape()) {
801 if (ShapedType::isDynamic(size))
802 sizes.push_back(dynamic[dynamicPos++]);
810 unsigned symbolNumber = 0;
811 for (
int i = rank - 2; i >= 0; --i) {
813 assert(i + 1 + symbolNumber == sizes.size() &&
814 "The ArrayRef should encompass the last #symbolNumber sizes");
817 sizesInvolvedInStride);
822 results.reserve(rank * 2 + 2);
824 auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
826 if (op.getBaseBuffer().use_empty()) {
827 results.push_back(
nullptr);
829 if (allocLikeOp.getType() == baseBufferType)
830 results.push_back(allocLikeOp);
832 results.push_back(rewriter.
create<memref::ReinterpretCastOp>(
833 loc, baseBufferType, allocLikeOp, offset,
839 results.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, offset));
866 struct ExtractStridedMetadataOpGetGlobalFolder
871 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
873 auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>();
877 auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
878 if (!memRefType.getLayout().isIdentity()) {
881 "get-global operation result should have been normalized");
885 int rank = memRefType.getRank();
889 assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
890 "unexpected dynamic shape for result of `memref.get_global` op");
897 results.reserve(rank * 2 + 2);
899 auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
901 if (getGlobalOp.getType() == baseBufferType)
902 results.push_back(getGlobalOp);
904 results.push_back(rewriter.
create<memref::ReinterpretCastOp>(
905 loc, baseBufferType, getGlobalOp, offset,
910 results.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, offset));
912 for (
auto size : sizes)
913 results.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, size));
915 for (
auto stride : strides)
916 results.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, stride));
925 class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
930 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
933 extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
937 extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
956 class ExtractStridedMetadataOpReinterpretCastFolder
961 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
963 auto reinterpretCastOp = extractStridedMetadataOp.getSource()
964 .getDefiningOp<memref::ReinterpretCastOp>();
965 if (!reinterpretCastOp)
968 Location loc = extractStridedMetadataOp.getLoc();
971 if (failed(extractStridedMetadataOp.inferReturnTypes(
972 rewriter.
getContext(), loc, {reinterpretCastOp.getSource()},
974 inferredReturnTypes)))
976 reinterpretCastOp,
"reinterpret_cast source's type is incompatible");
978 auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
979 unsigned rank = memrefType.getRank();
981 results.resize_for_overwrite(rank * 2 + 2);
983 auto newExtractStridedMetadata =
984 rewriter.
create<memref::ExtractStridedMetadataOp>(
985 loc, reinterpretCastOp.getSource());
988 results[0] = newExtractStridedMetadata.getBaseBuffer();
992 rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]);
994 const unsigned sizeStartIdx = 2;
995 const unsigned strideStartIdx = sizeStartIdx + rank;
999 for (
unsigned i = 0; i < rank; ++i) {
1000 results[sizeStartIdx + i] = sizes[i];
1001 results[strideStartIdx + i] = strides[i];
1003 rewriter.
replaceOp(extractStridedMetadataOp,
1030 class ExtractStridedMetadataOpCastFolder
1035 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1037 Value source = extractStridedMetadataOp.getSource();
1042 Location loc = extractStridedMetadataOp.getLoc();
1045 if (failed(extractStridedMetadataOp.inferReturnTypes(
1046 rewriter.
getContext(), loc, {castOp.getSource()},
1048 inferredReturnTypes)))
1050 "cast source's type is incompatible");
1052 auto memrefType = cast<MemRefType>(source.
getType());
1053 unsigned rank = memrefType.getRank();
1055 results.resize_for_overwrite(rank * 2 + 2);
1057 auto newExtractStridedMetadata =
1058 rewriter.
create<memref::ExtractStridedMetadataOp>(loc,
1059 castOp.getSource());
1062 results[0] = newExtractStridedMetadata.getBaseBuffer();
1064 auto getConstantOrValue = [&rewriter](int64_t constant,
1066 return !ShapedType::isDynamic(constant)
1072 assert(sourceStrides.size() == rank &&
"unexpected number of strides");
1076 getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
1078 const unsigned sizeStartIdx = 2;
1079 const unsigned strideStartIdx = sizeStartIdx + rank;
1084 for (
unsigned i = 0; i < rank; ++i) {
1085 results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
1086 results[strideStartIdx + i] =
1087 getConstantOrValue(sourceStrides[i], strides[i]);
1089 rewriter.
replaceOp(extractStridedMetadataOp,
1105 class ExtractStridedMetadataOpMemorySpaceCastFolder
1110 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1112 Location loc = extractStridedMetadataOp.getLoc();
1113 Value source = extractStridedMetadataOp.getSource();
1114 auto memSpaceCastOp = source.
getDefiningOp<memref::MemorySpaceCastOp>();
1115 if (!memSpaceCastOp)
1117 auto newExtractStridedMetadata =
1118 rewriter.
create<memref::ExtractStridedMetadataOp>(
1119 loc, memSpaceCastOp.getSource());
1127 if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) {
1128 auto baseBuffer = results[0];
1129 auto baseBufferType = cast<MemRefType>(baseBuffer.getType());
1131 newTypeBuilder.setMemorySpace(
1132 memSpaceCastOp.getResult().getType().getMemorySpace());
1133 results[0] = rewriter.
create<memref::MemorySpaceCastOp>(
1134 loc,
Type{newTypeBuilder}, baseBuffer);
1136 results[0] =
nullptr;
1138 rewriter.
replaceOp(extractStridedMetadataOp, results);
1150 class ExtractStridedMetadataOpExtractStridedMetadataFolder
1155 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1157 auto sourceExtractStridedMetadataOp =
1158 extractStridedMetadataOp.getSource()
1159 .getDefiningOp<memref::ExtractStridedMetadataOp>();
1160 if (!sourceExtractStridedMetadataOp)
1162 Location loc = extractStridedMetadataOp.getLoc();
1163 rewriter.
replaceOp(extractStridedMetadataOp,
1164 {sourceExtractStridedMetadataOp.getBaseBuffer(),
1174 patterns.
add<SubviewFolder,
1175 ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes,
1176 getExpandedStrides>,
1177 ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize,
1178 getCollapsedStride>,
1179 ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
1180 ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1181 ExtractStridedMetadataOpCollapseShapeFolder,
1182 ExtractStridedMetadataOpExpandShapeFolder,
1183 ExtractStridedMetadataOpGetGlobalFolder,
1184 RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
1185 ExtractStridedMetadataOpReinterpretCastFolder,
1186 ExtractStridedMetadataOpSubviewFolder,
1187 ExtractStridedMetadataOpCastFolder,
1188 ExtractStridedMetadataOpMemorySpaceCastFolder,
1189 ExtractStridedMetadataOpExtractStridedMetadataFolder>(
1195 patterns.
add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
1196 ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1197 ExtractStridedMetadataOpCollapseShapeFolder,
1198 ExtractStridedMetadataOpExpandShapeFolder,
1199 ExtractStridedMetadataOpGetGlobalFolder,
1200 ExtractStridedMetadataOpSubviewFolder,
1201 RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
1202 ExtractStridedMetadataOpReinterpretCastFolder,
1203 ExtractStridedMetadataOpCastFolder,
1204 ExtractStridedMetadataOpMemorySpaceCastFolder,
1205 ExtractStridedMetadataOpExtractStridedMetadataFolder>(
1215 struct ExpandStridedMetadataPass final
1216 :
public memref::impl::ExpandStridedMetadataBase<
1217 ExpandStridedMetadataPass> {
1218 void runOnOperation()
override;
1223 void ExpandStridedMetadataPass::runOnOperation() {
1230 return std::make_unique<ExpandStridedMetadataPass>();
static MLIRContext * getContext(OpFoldResult val)
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
AffineExpr getAffineConstantExpr(int64_t constant)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This is a builder type that keeps local references to arguments.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
std::unique_ptr< Pass > createExpandStridedMetadataPass()
Creates an operation pass to expand some memref operation into easier to reason about operations.
void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for resolving memref.extract_strided_metadata into memref.extract_strided_metadata o...
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for expanding memref operations that modify the metadata (sizes, offset,...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
SmallVector< int64_t > computeSuffixProduct(ArrayRef< int64_t > sizes)
Given a set of sizes, return the suffix product.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.