25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallBitVector.h"
31 #define GEN_PASS_DEF_EXPANDSTRIDEDMETADATA
32 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
41 struct StridedMetadata {
61 memref::SubViewOp subview) {
64 Value source = subview.getSource();
65 auto sourceType = cast<MemRefType>(source.
getType());
66 unsigned sourceRank = sourceType.getRank();
68 auto newExtractStridedMetadata =
69 rewriter.
create<memref::ExtractStridedMetadataOp>(origLoc, source);
78 auto origStrides = newExtractStridedMetadata.getStrides();
86 values[0] = ShapedType::isDynamic(sourceOffset)
88 : rewriter.getIndexAttr(sourceOffset);
93 for (
unsigned i = 0; i < sourceRank; ++i) {
96 ShapedType::isDynamic(sourceStrides[i])
100 rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));
103 unsigned baseIdxForDim = 1 + 2 * i;
104 unsigned subOffsetForDim = baseIdxForDim;
105 unsigned origStrideForDim = baseIdxForDim + 1;
106 expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
107 values[subOffsetForDim] = subOffsets[i];
108 values[origStrideForDim] = origStride;
118 auto subType = cast<MemRefType>(subview.getType());
119 unsigned subRank = subType.getRank();
128 llvm::SmallBitVector droppedDims = subview.getDroppedDims();
131 finalSizes.reserve(subRank);
134 finalStrides.reserve(subRank);
136 for (
unsigned i = 0; i < sourceRank; ++i) {
137 if (droppedDims.test(i))
140 finalSizes.push_back(subSizes[i]);
141 finalStrides.push_back(strides[i]);
143 assert(finalSizes.size() == subRank &&
144 "Should have populated all the values at this point");
145 return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset,
146 finalSizes, finalStrides};
170 resolveSubviewStridedMetadata(rewriter, subview);
171 if (
failed(stridedMetadata)) {
173 "failed to resolve subview metadata");
177 subview, subview.getType(), stridedMetadata->basePtr,
178 stridedMetadata->offset, stridedMetadata->sizes,
179 stridedMetadata->strides);
197 struct ExtractStridedMetadataOpSubviewFolder
201 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
203 auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>();
208 resolveSubviewStridedMetadata(rewriter, subviewOp);
209 if (
failed(stridedMetadata)) {
211 op,
"failed to resolve metadata in terms of source subview op");
215 results.reserve(subviewOp.getType().getRank() * 2 + 2);
216 results.push_back(stridedMetadata->basePtr);
218 stridedMetadata->offset));
222 stridedMetadata->strides));
245 getExpandedSizes(memref::ExpandShapeOp expandShape,
OpBuilder &builder,
248 expandShape.getReassociationIndices()[groupId];
249 assert(!reassocGroup.empty() &&
250 "Reassociation group should have at least one dimension");
252 unsigned groupSize = reassocGroup.size();
255 uint64_t productOfAllStaticSizes = 1;
256 std::optional<unsigned> dynSizeIdx;
257 MemRefType expandShapeType = expandShape.getResultType();
260 for (
unsigned i = 0; i < groupSize; ++i) {
261 uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
262 if (ShapedType::isDynamic(dimSize)) {
263 assert(!dynSizeIdx &&
"There must be at most one dynamic size per group");
267 productOfAllStaticSizes *= dimSize;
277 builder, expandShape.getLoc(), s0.
floorDiv(productOfAllStaticSizes),
281 return expandedSizes;
314 expandShape.getReassociationIndices()[groupId];
315 assert(!reassocGroup.empty() &&
316 "Reassociation group should have at least one dimension");
318 unsigned groupSize = reassocGroup.size();
319 MemRefType expandShapeType = expandShape.getResultType();
321 std::optional<int64_t> dynSizeIdx;
325 uint64_t currentStride = 1;
327 for (
int i = groupSize - 1; i >= 0; --i) {
328 expandedStrides[i] = builder.
getIndexAttr(currentStride);
329 uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
330 if (ShapedType::isDynamic(dimSize)) {
331 assert(!dynSizeIdx &&
"There must be at most one dynamic size per group");
336 currentStride *= dimSize;
340 Value source = expandShape.getSrc();
341 auto sourceType = cast<MemRefType>(source.
getType());
344 OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
345 ? origStrides[groupId]
349 int64_t doneStrideIdx = 0;
353 int64_t productOfAllStaticSizes = currentStride;
354 assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
355 "We shouldn't be able to change dynamicity");
360 for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
361 int64_t baseExpandedStride =
362 cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
365 builder, expandShape.getLoc(),
366 (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
367 {origSize, origStride});
373 for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
374 int64_t baseExpandedStride =
375 cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
378 builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
381 return expandedStrides;
398 unsigned numberOfSymbols = 0;
399 unsigned groupSize = indices.size();
400 for (
unsigned i = 0; i < groupSize; ++i) {
403 unsigned srcIdx = indices[i];
404 int64_t maybeConstant = maybeConstants[srcIdx];
406 inputValues.push_back(isDynamic(maybeConstant)
430 getCollapsedSize(memref::CollapseShapeOp collapseShape,
OpBuilder &builder,
434 MemRefType collapseShapeType = collapseShape.getResultType();
436 uint64_t size = collapseShapeType.getDimSize(groupId);
437 if (!ShapedType::isDynamic(size)) {
439 return collapsedSize;
445 Value source = collapseShape.getSrc();
446 auto sourceType = cast<MemRefType>(source.
getType());
449 collapseShape.getReassociationIndices()[groupId];
451 collapsedSize.push_back(getProductOfValues(
452 reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
453 origSizes, ShapedType::isDynamic));
455 return collapsedSize;
471 getCollapsedStride(memref::CollapseShapeOp collapseShape,
OpBuilder &builder,
475 collapseShape.getReassociationIndices()[groupId];
476 assert(!reassocGroup.empty() &&
477 "Reassociation group should have at least one dimension");
479 Value source = collapseShape.getSrc();
480 auto sourceType = cast<MemRefType>(source.
getType());
486 for (int64_t currentDim : reassocGroup) {
492 if (srcShape[currentDim] == 1)
495 int64_t currentStride = strides[currentDim];
496 groupStrides.push_back(ShapedType::isDynamic(currentStride)
497 ? origStrides[currentDim]
500 if (groupStrides.empty()) {
503 MemRefType collapsedType = collapseShape.getResultType();
504 auto [collapsedStrides, collapsedOffset] =
506 int64_t finalStride = collapsedStrides[groupId];
507 if (ShapedType::isDynamic(finalStride)) {
510 for (int64_t currentDim : reassocGroup) {
511 assert(srcShape[currentDim] == 1 &&
512 "We should be dealing with 1x1x...x1");
514 if (ShapedType::isDynamic(strides[currentDim]))
515 return {origStrides[currentDim]};
517 llvm_unreachable(
"We should have found a dynamic stride");
545 template <
typename ReassociativeReshapeLikeOp,
557 LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
561 Location origLoc = reshape.getLoc();
562 Value source = reshape.getSrc();
563 auto sourceType = cast<MemRefType>(source.
getType());
564 unsigned sourceRank = sourceType.getRank();
566 auto newExtractStridedMetadata =
567 rewriter.
create<memref::ExtractStridedMetadataOp>(origLoc, source);
571 MemRefType reshapeType = reshape.getResultType();
572 unsigned reshapeRank = reshapeType.getRank();
575 ShapedType::isDynamic(offset)
577 : rewriter.getIndexAttr(offset);
580 if (sourceRank == 0) {
582 auto memrefDesc = rewriter.
create<memref::ReinterpretCastOp>(
583 origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
584 offsetOfr, ones, ones);
585 rewriter.
replaceOp(reshape, memrefDesc.getResult());
590 finalSizes.reserve(reshapeRank);
592 finalStrides.reserve(reshapeRank);
599 unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
600 for (; idx != endIdx; ++idx) {
602 getReshapedSizes(reshape, rewriter, origSizes, idx);
604 reshape, rewriter, origSizes, origStrides, idx);
606 unsigned groupSize = reshapedSizes.size();
607 for (
unsigned i = 0; i < groupSize; ++i) {
608 finalSizes.push_back(reshapedSizes[i]);
609 finalStrides.push_back(reshapedStrides[i]);
612 assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
613 (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
614 "We should have visited all the input dimensions");
615 assert(finalSizes.size() == reshapeRank &&
616 "We should have populated all the values");
617 auto memrefDesc = rewriter.
create<memref::ReinterpretCastOp>(
618 origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
619 offsetOfr, finalSizes, finalStrides);
620 rewriter.
replaceOp(reshape, memrefDesc.getResult());
639 template <
typename AllocLikeOp>
640 struct ExtractStridedMetadataOpAllocFolder
645 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
647 auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>();
651 auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
652 if (!memRefType.getLayout().isIdentity())
654 allocLikeOp,
"alloc-like operations should have been normalized");
657 int rank = memRefType.getRank();
660 ValueRange dynamic = allocLikeOp.getDynamicSizes();
663 unsigned dynamicPos = 0;
664 for (int64_t size : memRefType.getShape()) {
665 if (ShapedType::isDynamic(size))
666 sizes.push_back(dynamic[dynamicPos++]);
674 unsigned symbolNumber = 0;
675 for (
int i = rank - 2; i >= 0; --i) {
677 assert(i + 1 + symbolNumber == sizes.size() &&
678 "The ArrayRef should encompass the last #symbolNumber sizes");
681 sizesInvolvedInStride);
686 results.reserve(rank * 2 + 2);
688 auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
691 results.push_back(
nullptr);
693 if (allocLikeOp.getType() == baseBufferType)
694 results.push_back(allocLikeOp);
696 results.push_back(rewriter.
create<memref::ReinterpretCastOp>(
697 loc, baseBufferType, allocLikeOp, offset,
703 results.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, offset));
730 struct ExtractStridedMetadataOpGetGlobalFolder
735 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
737 auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>();
741 auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
742 if (!memRefType.getLayout().isIdentity()) {
745 "get-global operation result should have been normalized");
749 int rank = memRefType.getRank();
753 assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
754 "unexpected dynamic shape for result of `memref.get_global` op");
761 results.reserve(rank * 2 + 2);
763 auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
765 if (getGlobalOp.getType() == baseBufferType)
766 results.push_back(getGlobalOp);
768 results.push_back(rewriter.
create<memref::ReinterpretCastOp>(
769 loc, baseBufferType, getGlobalOp, offset,
774 results.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, offset));
776 for (
auto size : sizes)
777 results.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, size));
779 for (
auto stride : strides)
780 results.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, stride));
789 class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
794 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
797 extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
801 extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
820 class ExtractStridedMetadataOpReinterpretCastFolder
825 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
827 auto reinterpretCastOp = extractStridedMetadataOp.getSource()
828 .getDefiningOp<memref::ReinterpretCastOp>();
829 if (!reinterpretCastOp)
832 Location loc = extractStridedMetadataOp.getLoc();
835 if (
failed(extractStridedMetadataOp.inferReturnTypes(
836 rewriter.
getContext(), loc, {reinterpretCastOp.getSource()},
838 inferredReturnTypes)))
840 reinterpretCastOp,
"reinterpret_cast source's type is incompatible");
842 auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
843 unsigned rank = memrefType.getRank();
845 results.resize_for_overwrite(rank * 2 + 2);
847 auto newExtractStridedMetadata =
848 rewriter.
create<memref::ExtractStridedMetadataOp>(
849 loc, reinterpretCastOp.getSource());
852 results[0] = newExtractStridedMetadata.getBaseBuffer();
856 rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]);
858 const unsigned sizeStartIdx = 2;
859 const unsigned strideStartIdx = sizeStartIdx + rank;
863 for (
unsigned i = 0; i < rank; ++i) {
864 results[sizeStartIdx + i] = sizes[i];
865 results[strideStartIdx + i] = strides[i];
867 rewriter.
replaceOp(extractStridedMetadataOp,
894 class ExtractStridedMetadataOpCastFolder
899 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
901 Value source = extractStridedMetadataOp.getSource();
906 Location loc = extractStridedMetadataOp.getLoc();
909 if (
failed(extractStridedMetadataOp.inferReturnTypes(
910 rewriter.
getContext(), loc, {castOp.getSource()},
912 inferredReturnTypes)))
914 "cast source's type is incompatible");
916 auto memrefType = cast<MemRefType>(source.
getType());
917 unsigned rank = memrefType.getRank();
919 results.resize_for_overwrite(rank * 2 + 2);
921 auto newExtractStridedMetadata =
922 rewriter.
create<memref::ExtractStridedMetadataOp>(loc,
926 results[0] = newExtractStridedMetadata.getBaseBuffer();
928 auto getConstantOrValue = [&rewriter](int64_t constant,
930 return !ShapedType::isDynamic(constant)
936 assert(sourceStrides.size() == rank &&
"unexpected number of strides");
940 getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
942 const unsigned sizeStartIdx = 2;
943 const unsigned strideStartIdx = sizeStartIdx + rank;
948 for (
unsigned i = 0; i < rank; ++i) {
949 results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
950 results[strideStartIdx + i] =
951 getConstantOrValue(sourceStrides[i], strides[i]);
953 rewriter.
replaceOp(extractStridedMetadataOp,
966 class ExtractStridedMetadataOpExtractStridedMetadataFolder
971 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
973 auto sourceExtractStridedMetadataOp =
974 extractStridedMetadataOp.getSource()
975 .getDefiningOp<memref::ExtractStridedMetadataOp>();
976 if (!sourceExtractStridedMetadataOp)
978 Location loc = extractStridedMetadataOp.getLoc();
979 rewriter.
replaceOp(extractStridedMetadataOp,
980 {sourceExtractStridedMetadataOp.getBaseBuffer(),
990 patterns.
add<SubviewFolder,
991 ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes,
993 ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize,
995 ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
996 ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
997 ExtractStridedMetadataOpGetGlobalFolder,
998 RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
999 ExtractStridedMetadataOpReinterpretCastFolder,
1000 ExtractStridedMetadataOpCastFolder,
1001 ExtractStridedMetadataOpExtractStridedMetadataFolder>(
1007 patterns.
add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
1008 ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1009 ExtractStridedMetadataOpGetGlobalFolder,
1010 ExtractStridedMetadataOpSubviewFolder,
1011 RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
1012 ExtractStridedMetadataOpReinterpretCastFolder,
1013 ExtractStridedMetadataOpCastFolder,
1014 ExtractStridedMetadataOpExtractStridedMetadataFolder>(
1024 struct ExpandStridedMetadataPass final
1025 :
public memref::impl::ExpandStridedMetadataBase<
1026 ExpandStridedMetadataPass> {
1027 void runOnOperation()
override;
1032 void ExpandStridedMetadataPass::runOnOperation() {
1039 return std::make_unique<ExpandStridedMetadataPass>();
static MLIRContext * getContext(OpFoldResult val)
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
AffineExpr getAffineConstantExpr(int64_t constant)
MLIRContext * getContext() const
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
bool use_empty()
Returns true if this operation has no uses.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
std::unique_ptr< Pass > createExpandStridedMetadataPass()
Creates an operation pass to expand some memref operation into easier to reason about operations.
void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for resolving memref.extract_strided_metadata into memref.extract_strided_metadata o...
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for expanding memref operations that modify the metadata (sizes, offset,...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
SmallVector< int64_t > computeSuffixProduct(ArrayRef< int64_t > sizes)
Given a set of sizes, return the suffix product.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...