25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallBitVector.h"
31 #define GEN_PASS_DEF_EXPANDSTRIDEDMETADATA
32 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
41 struct StridedMetadata {
61 memref::SubViewOp subview) {
64 Value source = subview.getSource();
65 auto sourceType = cast<MemRefType>(source.
getType());
66 unsigned sourceRank = sourceType.getRank();
68 auto newExtractStridedMetadata =
69 rewriter.
create<memref::ExtractStridedMetadataOp>(origLoc, source);
81 auto origStrides = newExtractStridedMetadata.getStrides();
89 values[0] = ShapedType::isDynamic(sourceOffset)
91 : rewriter.getIndexAttr(sourceOffset);
96 for (
unsigned i = 0; i < sourceRank; ++i) {
99 ShapedType::isDynamic(sourceStrides[i])
103 rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));
106 unsigned baseIdxForDim = 1 + 2 * i;
107 unsigned subOffsetForDim = baseIdxForDim;
108 unsigned origStrideForDim = baseIdxForDim + 1;
109 expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
110 values[subOffsetForDim] = subOffsets[i];
111 values[origStrideForDim] = origStride;
121 if (computedOffset && !ShapedType::isDynamic(resultOffset))
122 assert(*computedOffset == resultOffset &&
123 "mismatch between computed offset and result type offset");
129 auto subType = cast<MemRefType>(subview.getType());
130 unsigned subRank = subType.getRank();
139 llvm::SmallBitVector droppedDims = subview.getDroppedDims();
142 finalSizes.reserve(subRank);
145 finalStrides.reserve(subRank);
151 for (
unsigned i = 0; i < sourceRank; ++i) {
152 if (droppedDims.test(i))
155 finalSizes.push_back(subSizes[i]);
156 finalStrides.push_back(strides[i]);
161 if (computedStride && !ShapedType::isDynamic(resultStrides[
j]))
162 assert(*computedStride == resultStrides[
j] &&
163 "mismatch between computed stride and result type stride");
167 assert(finalSizes.size() == subRank &&
168 "Should have populated all the values at this point");
169 return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset,
170 finalSizes, finalStrides};
194 resolveSubviewStridedMetadata(rewriter, subview);
195 if (
failed(stridedMetadata)) {
197 "failed to resolve subview metadata");
201 subview, subview.getType(), stridedMetadata->basePtr,
202 stridedMetadata->offset, stridedMetadata->sizes,
203 stridedMetadata->strides);
221 struct ExtractStridedMetadataOpSubviewFolder
225 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
227 auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>();
232 resolveSubviewStridedMetadata(rewriter, subviewOp);
233 if (
failed(stridedMetadata)) {
235 op,
"failed to resolve metadata in terms of source subview op");
239 results.reserve(subviewOp.getType().getRank() * 2 + 2);
240 results.push_back(stridedMetadata->basePtr);
242 stridedMetadata->offset));
246 stridedMetadata->strides));
269 getExpandedSizes(memref::ExpandShapeOp expandShape,
OpBuilder &builder,
272 expandShape.getReassociationIndices()[groupId];
273 assert(!reassocGroup.empty() &&
274 "Reassociation group should have at least one dimension");
276 unsigned groupSize = reassocGroup.size();
279 uint64_t productOfAllStaticSizes = 1;
280 std::optional<unsigned> dynSizeIdx;
281 MemRefType expandShapeType = expandShape.getResultType();
284 for (
unsigned i = 0; i < groupSize; ++i) {
285 uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
286 if (ShapedType::isDynamic(dimSize)) {
287 assert(!dynSizeIdx &&
"There must be at most one dynamic size per group");
291 productOfAllStaticSizes *= dimSize;
301 builder, expandShape.getLoc(), s0.
floorDiv(productOfAllStaticSizes),
305 return expandedSizes;
338 expandShape.getReassociationIndices()[groupId];
339 assert(!reassocGroup.empty() &&
340 "Reassociation group should have at least one dimension");
342 unsigned groupSize = reassocGroup.size();
343 MemRefType expandShapeType = expandShape.getResultType();
345 std::optional<int64_t> dynSizeIdx;
349 uint64_t currentStride = 1;
351 for (
int i = groupSize - 1; i >= 0; --i) {
352 expandedStrides[i] = builder.
getIndexAttr(currentStride);
353 uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
354 if (ShapedType::isDynamic(dimSize)) {
355 assert(!dynSizeIdx &&
"There must be at most one dynamic size per group");
360 currentStride *= dimSize;
364 Value source = expandShape.getSrc();
365 auto sourceType = cast<MemRefType>(source.
getType());
368 OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
369 ? origStrides[groupId]
373 int64_t doneStrideIdx = 0;
377 int64_t productOfAllStaticSizes = currentStride;
378 assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
379 "We shouldn't be able to change dynamicity");
384 for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
385 int64_t baseExpandedStride =
386 cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
389 builder, expandShape.getLoc(),
390 (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
391 {origSize, origStride});
397 for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
398 int64_t baseExpandedStride =
399 cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
402 builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
405 return expandedStrides;
422 unsigned numberOfSymbols = 0;
423 unsigned groupSize = indices.size();
424 for (
unsigned i = 0; i < groupSize; ++i) {
427 unsigned srcIdx = indices[i];
428 int64_t maybeConstant = maybeConstants[srcIdx];
430 inputValues.push_back(isDynamic(maybeConstant)
454 getCollapsedSize(memref::CollapseShapeOp collapseShape,
OpBuilder &builder,
458 MemRefType collapseShapeType = collapseShape.getResultType();
460 uint64_t size = collapseShapeType.getDimSize(groupId);
461 if (!ShapedType::isDynamic(size)) {
463 return collapsedSize;
469 Value source = collapseShape.getSrc();
470 auto sourceType = cast<MemRefType>(source.
getType());
473 collapseShape.getReassociationIndices()[groupId];
475 collapsedSize.push_back(getProductOfValues(
476 reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
477 origSizes, ShapedType::isDynamic));
479 return collapsedSize;
495 getCollapsedStride(memref::CollapseShapeOp collapseShape,
OpBuilder &builder,
499 collapseShape.getReassociationIndices()[groupId];
500 assert(!reassocGroup.empty() &&
501 "Reassociation group should have at least one dimension");
503 Value source = collapseShape.getSrc();
504 auto sourceType = cast<MemRefType>(source.
getType());
510 for (int64_t currentDim : reassocGroup) {
516 if (srcShape[currentDim] == 1)
519 int64_t currentStride = strides[currentDim];
520 groupStrides.push_back(ShapedType::isDynamic(currentStride)
521 ? origStrides[currentDim]
524 if (groupStrides.empty()) {
527 MemRefType collapsedType = collapseShape.getResultType();
528 auto [collapsedStrides, collapsedOffset] =
530 int64_t finalStride = collapsedStrides[groupId];
531 if (ShapedType::isDynamic(finalStride)) {
534 for (int64_t currentDim : reassocGroup) {
535 assert(srcShape[currentDim] == 1 &&
536 "We should be dealing with 1x1x...x1");
538 if (ShapedType::isDynamic(strides[currentDim]))
539 return {origStrides[currentDim]};
541 llvm_unreachable(
"We should have found a dynamic stride");
569 template <
typename ReassociativeReshapeLikeOp,
581 LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
585 Location origLoc = reshape.getLoc();
586 Value source = reshape.getSrc();
587 auto sourceType = cast<MemRefType>(source.
getType());
588 unsigned sourceRank = sourceType.getRank();
590 auto newExtractStridedMetadata =
591 rewriter.
create<memref::ExtractStridedMetadataOp>(origLoc, source);
595 MemRefType reshapeType = reshape.getResultType();
596 unsigned reshapeRank = reshapeType.getRank();
599 ShapedType::isDynamic(offset)
601 : rewriter.getIndexAttr(offset);
604 if (sourceRank == 0) {
606 auto memrefDesc = rewriter.
create<memref::ReinterpretCastOp>(
607 origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
608 offsetOfr, ones, ones);
609 rewriter.
replaceOp(reshape, memrefDesc.getResult());
614 finalSizes.reserve(reshapeRank);
616 finalStrides.reserve(reshapeRank);
623 unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
624 for (; idx != endIdx; ++idx) {
626 getReshapedSizes(reshape, rewriter, origSizes, idx);
628 reshape, rewriter, origSizes, origStrides, idx);
630 unsigned groupSize = reshapedSizes.size();
631 for (
unsigned i = 0; i < groupSize; ++i) {
632 finalSizes.push_back(reshapedSizes[i]);
633 finalStrides.push_back(reshapedStrides[i]);
636 assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
637 (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
638 "We should have visited all the input dimensions");
639 assert(finalSizes.size() == reshapeRank &&
640 "We should have populated all the values");
641 auto memrefDesc = rewriter.
create<memref::ReinterpretCastOp>(
642 origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
643 offsetOfr, finalSizes, finalStrides);
644 rewriter.
replaceOp(reshape, memrefDesc.getResult());
663 template <
typename AllocLikeOp>
664 struct ExtractStridedMetadataOpAllocFolder
669 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
671 auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>();
675 auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
676 if (!memRefType.getLayout().isIdentity())
678 allocLikeOp,
"alloc-like operations should have been normalized");
681 int rank = memRefType.getRank();
684 ValueRange dynamic = allocLikeOp.getDynamicSizes();
687 unsigned dynamicPos = 0;
688 for (int64_t size : memRefType.getShape()) {
689 if (ShapedType::isDynamic(size))
690 sizes.push_back(dynamic[dynamicPos++]);
698 unsigned symbolNumber = 0;
699 for (
int i = rank - 2; i >= 0; --i) {
701 assert(i + 1 + symbolNumber == sizes.size() &&
702 "The ArrayRef should encompass the last #symbolNumber sizes");
705 sizesInvolvedInStride);
710 results.reserve(rank * 2 + 2);
712 auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
715 results.push_back(
nullptr);
717 if (allocLikeOp.getType() == baseBufferType)
718 results.push_back(allocLikeOp);
720 results.push_back(rewriter.
create<memref::ReinterpretCastOp>(
721 loc, baseBufferType, allocLikeOp, offset,
727 results.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, offset));
754 struct ExtractStridedMetadataOpGetGlobalFolder
759 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
761 auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>();
765 auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
766 if (!memRefType.getLayout().isIdentity()) {
769 "get-global operation result should have been normalized");
773 int rank = memRefType.getRank();
777 assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
778 "unexpected dynamic shape for result of `memref.get_global` op");
785 results.reserve(rank * 2 + 2);
787 auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
789 if (getGlobalOp.getType() == baseBufferType)
790 results.push_back(getGlobalOp);
792 results.push_back(rewriter.
create<memref::ReinterpretCastOp>(
793 loc, baseBufferType, getGlobalOp, offset,
798 results.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, offset));
800 for (
auto size : sizes)
801 results.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, size));
803 for (
auto stride : strides)
804 results.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, stride));
813 class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
818 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
821 extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
825 extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
844 class ExtractStridedMetadataOpReinterpretCastFolder
849 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
851 auto reinterpretCastOp = extractStridedMetadataOp.getSource()
852 .getDefiningOp<memref::ReinterpretCastOp>();
853 if (!reinterpretCastOp)
856 Location loc = extractStridedMetadataOp.getLoc();
859 if (
failed(extractStridedMetadataOp.inferReturnTypes(
860 rewriter.
getContext(), loc, {reinterpretCastOp.getSource()},
862 inferredReturnTypes)))
864 reinterpretCastOp,
"reinterpret_cast source's type is incompatible");
866 auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
867 unsigned rank = memrefType.getRank();
869 results.resize_for_overwrite(rank * 2 + 2);
871 auto newExtractStridedMetadata =
872 rewriter.
create<memref::ExtractStridedMetadataOp>(
873 loc, reinterpretCastOp.getSource());
876 results[0] = newExtractStridedMetadata.getBaseBuffer();
880 rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]);
882 const unsigned sizeStartIdx = 2;
883 const unsigned strideStartIdx = sizeStartIdx + rank;
887 for (
unsigned i = 0; i < rank; ++i) {
888 results[sizeStartIdx + i] = sizes[i];
889 results[strideStartIdx + i] = strides[i];
891 rewriter.
replaceOp(extractStridedMetadataOp,
918 class ExtractStridedMetadataOpCastFolder
923 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
925 Value source = extractStridedMetadataOp.getSource();
930 Location loc = extractStridedMetadataOp.getLoc();
933 if (
failed(extractStridedMetadataOp.inferReturnTypes(
934 rewriter.
getContext(), loc, {castOp.getSource()},
936 inferredReturnTypes)))
938 "cast source's type is incompatible");
940 auto memrefType = cast<MemRefType>(source.
getType());
941 unsigned rank = memrefType.getRank();
943 results.resize_for_overwrite(rank * 2 + 2);
945 auto newExtractStridedMetadata =
946 rewriter.
create<memref::ExtractStridedMetadataOp>(loc,
950 results[0] = newExtractStridedMetadata.getBaseBuffer();
952 auto getConstantOrValue = [&rewriter](int64_t constant,
954 return !ShapedType::isDynamic(constant)
960 assert(sourceStrides.size() == rank &&
"unexpected number of strides");
964 getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
966 const unsigned sizeStartIdx = 2;
967 const unsigned strideStartIdx = sizeStartIdx + rank;
972 for (
unsigned i = 0; i < rank; ++i) {
973 results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
974 results[strideStartIdx + i] =
975 getConstantOrValue(sourceStrides[i], strides[i]);
977 rewriter.
replaceOp(extractStridedMetadataOp,
990 class ExtractStridedMetadataOpExtractStridedMetadataFolder
995 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
997 auto sourceExtractStridedMetadataOp =
998 extractStridedMetadataOp.getSource()
999 .getDefiningOp<memref::ExtractStridedMetadataOp>();
1000 if (!sourceExtractStridedMetadataOp)
1002 Location loc = extractStridedMetadataOp.getLoc();
1003 rewriter.
replaceOp(extractStridedMetadataOp,
1004 {sourceExtractStridedMetadataOp.getBaseBuffer(),
1014 patterns.
add<SubviewFolder,
1015 ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes,
1016 getExpandedStrides>,
1017 ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize,
1018 getCollapsedStride>,
1019 ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
1020 ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1021 ExtractStridedMetadataOpGetGlobalFolder,
1022 RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
1023 ExtractStridedMetadataOpReinterpretCastFolder,
1024 ExtractStridedMetadataOpCastFolder,
1025 ExtractStridedMetadataOpExtractStridedMetadataFolder>(
1031 patterns.
add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
1032 ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1033 ExtractStridedMetadataOpGetGlobalFolder,
1034 ExtractStridedMetadataOpSubviewFolder,
1035 RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
1036 ExtractStridedMetadataOpReinterpretCastFolder,
1037 ExtractStridedMetadataOpCastFolder,
1038 ExtractStridedMetadataOpExtractStridedMetadataFolder>(
1048 struct ExpandStridedMetadataPass final
1049 :
public memref::impl::ExpandStridedMetadataBase<
1050 ExpandStridedMetadataPass> {
1051 void runOnOperation()
override;
1056 void ExpandStridedMetadataPass::runOnOperation() {
1063 return std::make_unique<ExpandStridedMetadataPass>();
static MLIRContext * getContext(OpFoldResult val)
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
AffineExpr getAffineConstantExpr(int64_t constant)
MLIRContext * getContext() const
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
bool use_empty()
Returns true if this operation has no uses.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
std::unique_ptr< Pass > createExpandStridedMetadataPass()
Creates an operation pass to expand some memref operation into easier to reason about operations.
void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for resolving memref.extract_strided_metadata into memref.extract_strided_metadata o...
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for expanding memref operations that modify the metadata (sizes, offset,...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
SmallVector< int64_t > computeSuffixProduct(ArrayRef< int64_t > sizes)
Given a set of sizes, return the suffix product.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.