16#include "llvm/ADT/STLExtras.h"
17#include "llvm/Support/LogicalResult.h"
// Pattern: folds tensor.expand_shape(tensor.extract_slice) when the
// extract_slice is rank-reducing and the expand_shape re-introduces exactly
// the dropped unit dims — visible below via the comparison of the inferred
// non-rank-reducing extract type against the expand_shape result type.
// NOTE(review): this listing is an incomplete doxygen extraction; original
// source lines are missing between the fragments (embedded line numbers
// jump). Confirm against upstream MLIR ReshapePatterns.cpp before editing.
24struct FoldExpandOfRankReducingExtract
26 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
28 LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
29 PatternRewriter &rewriter)
const override {
30 RankedTensorType resultType = expandShapeOp.getResultType();
// Source of the expand_shape must be produced by an extract_slice.
32 expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
35 RankedTensorType srcType = extractSliceOp.getSourceType();
// Infer what a non-rank-reducing extract of the same static sizes would
// yield; the fold only applies when that equals the expand_shape result.
40 RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
41 srcType, extractSliceOp.getStaticSizes());
42 if (nonReducingExtractType != resultType)
45 SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
46 SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
47 SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
// Replacement op creation is truncated here; presumably a single
// ExtractSliceOp replaces the expand_shape — TODO confirm upstream.
49 expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
// Pattern: folds an "unpadding" tensor.collapse_shape (one that only drops
// unit dims, checked via a SliceVerificationResult below) into its producing
// tensor.extract_slice, emitting a single extract_slice with the collapsed
// result type.
// NOTE(review): incomplete doxygen extraction — source lines are missing
// between fragments; verify against upstream before editing.
57struct FoldUnPaddingCollapseIntoExtract
59 using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;
61 LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
62 PatternRewriter &rewriter)
const override {
64 collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
// Single-use guard: folding would otherwise duplicate the slice for the
// extract_slice's other users.
68 if (!extractSliceOp || !extractSliceOp->hasOneUse())
// The collapse must be a pure rank reduction (unit dims only).
74 collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
75 if (res != SliceVerificationResult::Success)
77 "expected unpadding collapse");
// Rebuild the extract_slice directly with the collapsed result type,
// reusing the original offsets/sizes/strides.
79 Value unPaddedExtractSlice = tensor::ExtractSliceOp::create(
80 rewriter, extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
81 extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
82 extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
83 rewriter.
replaceOp(collapseShapeOp, unPaddedExtractSlice);
// Pattern (struct declaration line missing from this extraction; the name
// FoldInsertOfRankReducingInsert is taken from its registration further
// down): folds a rank-reducing insert op whose source comes from a
// collapse_shape, when the collapse source type matches the inferred
// non-rank-reducing insert type. Templated over the insert op kind
// (tensor::InsertSliceOp / tensor::ParallelInsertSliceOp per registration).
// NOTE(review): incomplete doxygen extraction — confirm against upstream.
89template <
typename OpTy>
91 using OpRewritePattern<OpTy>::OpRewritePattern;
93 LogicalResult matchAndRewrite(OpTy insertSliceOp,
94 PatternRewriter &rewriter)
const override {
95 auto collapseShapeOp =
96 insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
99 RankedTensorType srcType = collapseShapeOp.getSrcType();
// What a non-rank-reducing insert of these static sizes would carry; the
// fold only applies when that equals the collapse_shape's source type.
104 RankedTensorType nonReducingInsertType =
105 RankedTensorType::get(insertSliceOp.getStaticSizes(),
106 insertSliceOp.getDestType().getElementType());
107 if (nonReducingInsertType != srcType)
110 SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
111 SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
112 SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
// Replacement creation is truncated; presumably a new insert op using the
// collapse_shape's source — TODO confirm upstream.
114 insertSliceOp.getDest(), mixedOffsets,
115 mixedSizes, mixedStrides);
// Pattern (struct declaration line missing from this extraction; the name
// FoldPaddingExpandIntoInsert is taken from its registration further down):
// folds a rank-increasing (unit-dim-adding) tensor.expand_shape into a
// consuming insert op by rewiring the insert's source to the expand's
// source in place. Templated over the insert op kind.
// NOTE(review): incomplete doxygen extraction — confirm against upstream.
122template <
typename OpTy>
124 using OpRewritePattern<OpTy>::OpRewritePattern;
126 LogicalResult matchAndRewrite(OpTy insertSliceOp,
127 PatternRewriter &rewriter)
const override {
128 auto expandShapeOp = insertSliceOp.getSource()
129 .template getDefiningOp<tensor::ExpandShapeOp>();
// The expansion must only add unit dims (rank-reduction check with the
// operands swapped: result type vs. source type).
136 expandShapeOp.getResultType(), expandShapeOp.getSrcType());
137 if (res != SliceVerificationResult::Success)
139 "expected rank increasing expansion");
// In-place rewrite: point the insert's source at the pre-expand value.
142 insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
// Pattern: bubbles a tensor.expand_shape up through a producing
// tensor.collapse_shape when their reassociations are "parallel" — each
// group pair either matches 1:1 or exactly one side is a singleton group.
// It builds new reassociation maps and output sizes, then re-emits an
// expand followed by a collapse in the swapped order (both emitted only
// when non-trivial, per the size-inequality guards near the end).
// NOTE(review): this block is heavily fragmented by the doxygen extraction
// (many source lines missing, several statements truncated mid-expression);
// read alongside upstream MLIR ReshapePatterns.cpp — do not edit from this
// listing alone.
150struct BubbleUpExpandThroughParallelCollapse
152 using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
154 LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
155 PatternRewriter &rewriter)
const override {
157 expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
160 auto expandReInds = expandOp.getReassociationIndices();
161 auto collapseReInds = collapseOp.getReassociationIndices();
// Degenerate 0-d case handled separately (body truncated here).
165 if (expandReInds.size() == 0) {
// Pre-scan: for matching-size group pairs, the static shapes must agree
// and each side may contain at most one dynamic dim, otherwise the
// bubbling cannot preserve the shapes.
174 ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
175 ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
176 for (
auto [expandReassociation, collapseReassociation] :
177 llvm::zip_equal(expandReInds, collapseReInds)) {
178 if (collapseReassociation.size() == expandReassociation.size()) {
184 ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
185 collapseReassociation.front(), collapseReassociation.size());
186 int64_t numCollapsedDynamic =
187 llvm::count_if(collapsedStaticShapes, ShapedType::isDynamic);
188 ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
189 expandReassociation.front(), expandReassociation.size());
190 int64_t numExpandedDynamic =
191 llvm::count_if(expandedStaticShapes, ShapedType::isDynamic);
192 if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
193 collapsedStaticShapes != expandedStaticShapes) {
// Otherwise exactly one of the two groups must be a singleton.
200 if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
// Build the swapped-order reassociations and the new expand sizes.
205 SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
206 Location loc = expandOp->getLoc();
207 SmallVector<OpFoldResult> sourceSizes =
209 SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
210 SmallVector<OpFoldResult> newExpandSizes;
212 int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
215 for (
size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
216 auto &collapseReassociation = collapseReInds[idx];
217 auto &expandReassociation = expandReInds[idx];
// Case 1: groups match 1:1 — both sides become singleton groups.
225 if (collapseReassociation.size() == expandReassociation.size()) {
226 for (
size_t i = 0; i < collapseReassociation.size(); ++i) {
227 newCollapseReInds.push_back({newCollapseIndex++});
228 newExpandReInds.push_back({newExpandIndex++});
229 newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
// Case 2: only the collapse side groups — keep its grouping on the new
// collapse, expand side stays elementwise with source sizes.
238 if (collapseReassociation.size() != 1) {
240 for (
size_t i = 0; i < collapseReassociation.size(); ++i) {
241 newCollapseReassociation.push_back(newCollapseIndex++);
242 newExpandReInds.push_back({newExpandIndex++});
243 newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
246 newCollapseReInds.push_back(newCollapseReassociation);
// Case 3: only the expand side groups — mirror of case 2.
254 for (
size_t i = 0; i < expandReassociation.size(); ++i) {
255 newExpandReassociation.push_back(newExpandIndex++);
256 newCollapseReInds.push_back({newCollapseIndex++});
257 newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
259 newExpandReInds.push_back(newExpandReassociation);
// Materialize: expand first (if non-trivial), then collapse (if
// non-trivial) on the collapse op's original source.
264 SmallVector<Value> dynamicSizes;
265 SmallVector<int64_t> staticSizes;
267 auto expandResultType = expandOp.getResultType().clone(staticSizes);
268 Value newCollapseSrc = collapseOp.getSrc();
272 if (newExpandReInds.size() != newExpandSizes.size()) {
273 newCollapseSrc = tensor::ExpandShapeOp::create(
274 rewriter, loc, expandResultType, newCollapseSrc, newExpandReInds,
282 if (newCollapseReInds.size() != newExpandSizes.size()) {
284 rewriter, loc, newCollapseSrc, newCollapseReInds);
// Pattern: bubbles a tensor.extract_slice above a producing
// tensor.expand_shape — slices the pre-expansion source (offsets/sizes/
// strides computed by a helper whose call line is truncated here,
// presumably getCollapsedExtractSliceInfo) and re-applies the expansion on
// the smaller slice.
// NOTE(review): incomplete doxygen extraction — confirm against upstream.
322struct BubbleUpExtractSliceThroughExpandShape
324 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
326 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
327 PatternRewriter &rewriter)
const override {
329 sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
330 if (!expandShapeOp) {
332 sliceOp,
"tensor.extract_slice source not produced by expand_shape");
334 SmallVector<ReassociationIndices> reassociation =
335 expandShapeOp.getReassociationIndices();
// Compute collapsed-domain offsets/sizes/strides (call truncated).
337 SmallVector<OpFoldResult> offsets, sizes, strides;
339 offsets, sizes, strides)))
343 SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
344 RankedTensorType resultType = sliceOp.getResultType();
347 Location loc = sliceOp.getLoc();
// New slice of the expand_shape's source, then re-expand to the original
// slice result type with the original reassociation.
348 Value newSliceOp = tensor::ExtractSliceOp::create(
349 rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides);
351 sliceOp, resultType, newSliceOp,
352 expandShapeOp.getReassociationIndices(), expandedSizes);
// Pattern: bubbles a tensor.extract_slice above a producing
// tensor.collapse_shape — computes expanded-domain offsets/sizes/strides
// (call fragment at "446" below, presumably getExpandedExtractSliceInfo),
// slices the pre-collapse source, then re-collapses to the original result
// type.
// NOTE(review): incomplete doxygen extraction — confirm against upstream.
430struct BubbleUpExtractSliceThroughCollapseShape
432 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
434 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
435 PatternRewriter &rewriter)
const override {
436 auto collapseShapeOp =
437 sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
438 if (!collapseShapeOp) {
441 "tensor.extract_slice source not produced by tensor.collapse_shape");
442 SmallVector<OpFoldResult> offsets, sizes, strides;
446 rewriter, sliceOp, collapseShapeOp.getReassociationIndices(),
447 collapseShapeOp.getSrc(), offsets, sizes, strides)))
450 Value newSliceOp = tensor::ExtractSliceOp::create(
451 rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), offsets,
// Re-collapse the new slice with the original reassociation.
454 sliceOp, sliceOp.getResultType(), newSliceOp,
455 collapseShapeOp.getReassociationIndices());
// Fragments of getCollapsedExtractSliceInfo (per the doxygen brief further
// down): computes offsets/sizes/strides for a collapsed extract_slice.
// Visible logic: requires unit strides and a full-rank (non-rank-reducing)
// slice; for each reassociation group it linearizes the expanded offsets
// via affine.linearize_index, multiplies sizes into one collapsed size, and
// emits stride 1.
// NOTE(review): heavily fragmented extraction — function signature and most
// control flow are missing; confirm against upstream before editing.
469 if (!sliceOp.hasUnitStride()) {
476 if (
static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) {
// ValueBoundsConstraintSet-style equality query; only a proven-equal
// result counts as success.
485 FailureOr<bool> maybeEqual =
487 return llvm::succeeded(maybeEqual) && maybeEqual.value();
// Dims not participating in a real slice must be zero-offset, full-size.
517 if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
// Per-group accumulation of sizes/offsets for linearization.
557 for (
long expandedDim :
indices) {
561 reassocGroupSizes.push_back(expandedShape[expandedDim]);
562 reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
563 collapsedSize =
mul(collapsedSize, expandedSizes[expandedDim]);
567 llvm::map_to_vector(reassocGroupOffsets, [&](
OpFoldResult ofr) {
// Collapse the multi-dim offset into one index.
570 OpFoldResult collapsedOffset = affine::AffineLinearizeIndexOp::create(
571 b, loc, offsetVals, reassocGroupSizes,
574 collapsedOffsets.push_back(collapsedOffset);
575 collapsedSizes.push_back(collapsedSize);
// Collapsed slices are always unit-stride (checked at entry).
578 collapsedStrides.push_back(
b.getIndexAttr(1));
// Fragments of isMultipleOf(OpFoldResult, int64_t) (per doxygen brief
// below): static constants are checked with %; the dynamic-Value path is
// truncated in this extraction.
588 if (staticValue.has_value())
589 return staticValue.value() % factor == 0;
591 Value value = dyn_cast<Value>(ofr);
// Fragments of computeExpandedSliceInfoForReassocGroup (per the doxygen
// brief further down): given one collapsed (size, offset) pair and its
// reassociation group, derives per-expanded-dim sizes for the group,
// appending into groupSizes. Visible structure: special cases for a single
// non-unit dim and for collapsedSize == 1, then a general loop that peels
// expanded dims from the innermost (reversed) end while divisibility and
// offset-alignment conditions hold, padding the remainder with 1s and
// reversing at the end.
// NOTE(review): heavily fragmented extraction — many guard branches and the
// function signature are missing; confirm against upstream before editing.
616 assert(groupSizes.empty() &&
"Group sizes must be empty");
// Fast path: exactly one non-unit expanded dim in the group — the whole
// collapsed size maps onto that dim, all others get size 1.
622 int nonUnitSizeCount = llvm::count_if(
623 reassocIndices, [&expandedShape](
int64_t expandedShapeIdx) {
624 return expandedShape[expandedShapeIdx] != 1;
626 if (nonUnitSizeCount == 1) {
627 for (
int64_t expandedShapeIdx : reassocIndices) {
628 if (expandedShape[expandedShapeIdx] != 1)
629 groupSizes.push_back(collapsedSize);
631 groupSizes.push_back(
b.getIndexAttr(1));
// Dynamic collapsed sizes are rejected (branch body truncated).
638 if (isa<Value>(collapsedSize))
642 assert(staticSize.has_value() &&
"Expected static size");
// Fast path: collapsed size 1 — every expanded dim in the group is 1.
647 if (staticSize.value() == 1) {
648 for (
size_t i = 0; i < reassocIndices.size(); ++i)
649 groupSizes.push_back(
b.getIndexAttr(1));
// General case: walk the group innermost-first, consuming whole expanded
// dims while the remaining collapsed size divides evenly and the offset
// stays aligned to the consumed extent.
672 assert(staticSize.value() > 1 &&
"Expected size to be greater than 1");
673 int64_t currentCollapsedsize = staticSize.value();
674 int64_t currentOffsetDivisor = 1;
677 reassocIndices.rend());
679 int64_t reassocGroupSize = reassocIndices.size();
683 for (; idx < reassocGroupSize; ++idx) {
684 int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
685 if (expandedShapeSize == ShapedType::kDynamic)
688 if (currentCollapsedsize < expandedShapeSize)
692 if ((currentCollapsedsize % expandedShapeSize) != 0)
696 currentOffsetDivisor *= expandedShapeSize;
697 if (!
isMultipleOf(collapsedOffset, currentOffsetDivisor))
701 groupSizes.push_back(
b.getIndexAttr(expandedShapeSize));
702 currentCollapsedsize /= expandedShapeSize;
// Leftover partial dim: the residual collapsed size must fit within one
// expanded dim without crossing its boundary.
706 if (idx < reassocGroupSize) {
707 int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
710 if (staticOffset.has_value()) {
713 (staticOffset.value() / currentOffsetDivisor) % expandedShapeSize;
714 if ((currentCollapsedsize + offsetInDim) > expandedShapeSize)
723 if ((expandedShapeSize % currentCollapsedsize) != 0)
729 groupSizes.push_back(
b.getIndexAttr(currentCollapsedsize));
// Remaining (outer) dims contribute size 1.
739 for (idx++; idx < reassocGroupSize; ++idx)
740 groupSizes.push_back(
b.getIndexAttr(1));
// Sizes were accumulated innermost-first; restore group order.
743 groupSizes = llvm::to_vector(llvm::reverse(groupSizes));
// Fragments of getExpandedExtractSliceInfo (per the doxygen brief further
// down): computes offsets/sizes/strides for an expanded extract_slice.
// Visible logic: requires unit strides and a full-rank slice; per
// reassociation group it delegates size computation (presumably to
// computeExpandedSliceInfoForReassocGroup), sets all expanded strides to 1,
// and delinearizes each collapsed offset across the group's basis via
// affine.delinearize_index.
// NOTE(review): heavily fragmented extraction — confirm against upstream.
753 if (!sliceOp.hasUnitStride()) {
764 if (
static_cast<size_t>(sliceOp.getResultType().getRank()) !=
765 collapsedSizes.size()) {
774 cast<RankedTensorType>(expandedValue.
getType()).getShape();
// Pass 1: per-group expanded sizes.
776 for (
auto [collapsedSize, collapsedOffset, reassocIndices] :
777 llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) {
781 b, collapsedSize, collapsedOffset, reassocIndices, expandedShape,
785 groupResults.emplace_back(groupSizes);
// All expanded dims are unit-stride (checked at entry).
788 expandedStrides.resize(expandedShape.size(),
b.getIndexAttr(1));
// Pass 2: append sizes and delinearize each collapsed offset.
789 for (
auto [groupIdx, reassocIndices] : llvm::enumerate(reassociation)) {
790 auto &sizes = groupResults[groupIdx];
791 expandedSizes.append(sizes);
794 for (
int64_t expandedShapeIdx : reassocIndices)
798 OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
801 auto delinearizeOp = affine::AffineDelinearizeIndexOp::create(
802 b, sliceOp.getLoc(), offsetVal, basis,
true);
804 expandedOffsets.push_back(
result);
// Fragment of populateReassociativeReshapeFoldingPatterns (per the doxygen
// brief further down): registers the reshape-folding patterns defined
// above, instantiating the templated insert-folding patterns for both
// InsertSliceOp and ParallelInsertSliceOp.
812 .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
813 FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
814 FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
815 FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
816 FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
// Fragment of populateBubbleUpExtractSliceOpPatterns (per the doxygen brief
// further down): registers the two extract_slice bubble-up patterns.
827 patterns.add<BubbleUpExtractSliceThroughExpandShape,
828 BubbleUpExtractSliceThroughCollapseShape>(
patterns.getContext());
* if copies could not be generated due to as-yet-unimplemented cases. * `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` * specify the insertion points where the incoming and outgoing copies should be * inserted. The output argument `nBegin` is set to the replacement for `begin` * (set to `begin` itself if no invalidation happens), since outgoing * copies could have been inserted at `end`.
static LogicalResult computeExpandedSliceInfoForReassocGroup(OpBuilder &b, OpFoldResult collapsedSize, OpFoldResult collapsedOffset, const ReassociationIndices &reassocIndices, ArrayRef< int64_t > expandedShape, SmallVectorImpl< OpFoldResult > &groupSizes)
Given a collapsedOffset and collapsedSize, this function validates that the slice is representable as...
static bool isMultipleOf(OpFoldResult ofr, int64_t factor)
Base type for affine expression.
bool isMultipleOf(int64_t factor) const
Return true if the affine expression is a multiple of 'factor'.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
This is a value defined by a result of an operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
A variable that can be added to the constraint set as a "column".
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands, bool composeAffineMin=false)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
LogicalResult getCollapsedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp, ArrayRef< ReassociationIndices > reassociation, SmallVectorImpl< OpFoldResult > &collapsedOffsets, SmallVectorImpl< OpFoldResult > &collapsedSizes, SmallVectorImpl< OpFoldResult > &collapsedStrides)
Computes the offsets, sizes, and strides needed to build a collapsed sliceOp.
LogicalResult getExpandedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp, ArrayRef< ReassociationIndices > reassociation, Value expandedValue, SmallVectorImpl< OpFoldResult > &expandedOffsets, SmallVectorImpl< OpFoldResult > &expandedSizes, SmallVectorImpl< OpFoldResult > &expandedStrides)
Computes the offsets, sizes, and strides needed to build an expanded sliceOp.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other o...
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor....
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions [0 .. N).
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
SmallVector< int64_t, 2 > ReassociationIndices
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...