15#include "llvm/ADT/STLExtras.h"
16#include "llvm/Support/LogicalResult.h"
// Pattern: fold a tensor.expand_shape whose source is a rank-reducing
// tensor.extract_slice into a single, non-rank-reducing extract_slice.
// NOTE(review): the extraction of this file dropped lines (the fused
// original line numbers jump, e.g. 29 -> 31); comments below describe only
// the statements that are visible.
23struct FoldExpandOfRankReducingExtract
25 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
27 LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
28 PatternRewriter &rewriter)
const override {
// Type produced by the expand_shape; the replacement must match it exactly.
29 RankedTensorType resultType = expandShapeOp.getResultType();
// Walk to the producing extract_slice (null-check presumably elided).
31 expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
34 RankedTensorType srcType = extractSliceOp.getSourceType();
// Infer what a NON-rank-reducing slice with the same static sizes would
// yield; the fold applies only when that equals the expand_shape result.
39 RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
40 srcType, extractSliceOp.getStaticSizes())
41 if (nonReducingExtractType != resultType)
// Reuse the original slice parameters for the replacement op.
44 SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
45 SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
46 SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
// Replace the expand_shape with a direct slice of the original source
// (the replaceOpWithNewOp call appears truncated by extraction).
48 expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
// Pattern: fold a tensor.collapse_shape that merely un-pads (drops unit
// dims from) its producing tensor.extract_slice back into that slice.
// NOTE(review): extraction dropped interior lines; only visible statements
// are documented.
56struct FoldUnPaddingCollapseIntoExtract
58 using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;
60 LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
61 PatternRewriter &rewriter)
const override {
// The collapse's source must come from an extract_slice ...
63 collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
// ... with a single use, so rewriting the slice in place is safe.
67 if (!extractSliceOp || !extractSliceOp->hasOneUse())
// The collapse must be a pure rank reduction (result type obtainable from
// the source type by dropping dims); verified via isRankReducedType,
// presumably — the call itself was elided by extraction.
73 collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
74 if (res != SliceVerificationResult::Success)
76 "expected unpadding collapse");
// Rebuild the slice directly with the collapsed (rank-reduced) result type
// but the original offsets/sizes/strides.
78 Value unPaddedExtractSlice = tensor::ExtractSliceOp::create(
79 rewriter, extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
80 extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
81 extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
82 rewriter.
replaceOp(collapseShapeOp, unPaddedExtractSlice);
// Pattern over insert-like ops (instantiated for tensor::InsertSliceOp and
// tensor::ParallelInsertSliceOp at registration): fold a rank-reducing
// tensor.collapse_shape feeding the insert's source into the insert itself.
// NOTE(review): the struct-name line was dropped by extraction; the
// registration site below names this FoldInsertOfRankReducingInsert.
88template <
typename OpTy>
90 using OpRewritePattern<OpTy>::OpRewritePattern;
92 LogicalResult matchAndRewrite(OpTy insertSliceOp,
93 PatternRewriter &rewriter)
const override {
// Source must be produced by a collapse_shape (null-check elided).
94 auto collapseShapeOp =
95 insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
98 RankedTensorType srcType = collapseShapeOp.getSrcType();
// The fold applies only when a NON-rank-reducing insert of the same static
// sizes would have exactly the collapse's source type.
103 RankedTensorType nonReducingInsertType =
104 RankedTensorType::get(insertSliceOp.getStaticSizes(),
105 insertSliceOp.getDestType().getElementType());
106 if (nonReducingInsertType != srcType)
// Reuse the original insert's slice parameters.
109 SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
110 SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
111 SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
// Replacement-op creation appears truncated by extraction; the visible
// operands insert the collapse's source directly into the original dest.
113 insertSliceOp.getDest(), mixedOffsets,
114 mixedSizes, mixedStrides);
// Pattern over insert-like ops (InsertSliceOp / ParallelInsertSliceOp):
// when the insert's source is produced by a rank-increasing (unit-dim
// padding) tensor.expand_shape, bypass the expand and use its source
// directly. NOTE(review): the struct-name line was dropped by extraction;
// the registration site below names this FoldPaddingExpandIntoInsert.
121template <
typename OpTy>
123 using OpRewritePattern<OpTy>::OpRewritePattern;
125 LogicalResult matchAndRewrite(OpTy insertSliceOp,
126 PatternRewriter &rewriter)
const override {
// Source must come from an expand_shape (null-check elided).
127 auto expandShapeOp = insertSliceOp.getSource()
128 .template getDefiningOp<tensor::ExpandShapeOp>();
// The expansion must only add unit dims: the src type must be the result
// type rank-reduced (verification call elided by extraction).
135 expandShapeOp.getResultType(), expandShapeOp.getSrcType());
136 if (res != SliceVerificationResult::Success)
138 "expected rank increasing expansion");
// In-place update: re-point the insert's source at the expand's operand
// (presumably wrapped in rewriter.modifyOpInPlace at the elided call site).
141 insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
// Pattern: bubble a tensor.expand_shape up through its producing
// tensor.collapse_shape when the two reassociations are "parallel", so the
// expansion is applied before the collapse, on the larger tensor.
// NOTE(review): extraction dropped many interior lines (returns, braces);
// only visible statements are documented.
149struct BubbleUpExpandThroughParallelCollapse
151 using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
153 LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
154 PatternRewriter &rewriter)
const override {
// The expand's source must come from a collapse_shape (check elided).
156 expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
159 auto expandReInds = expandOp.getReassociationIndices();
160 auto collapseReInds = collapseOp.getReassociationIndices();
// Empty reassociation (0-d case) handled separately (body elided).
164 if (expandReInds.size() == 0) {
// Screen each paired reassociation group. Same-rank groups must have
// identical static shapes with at most one dynamic dim on each side.
173 ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
174 ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
175 for (
auto [expandReassociation, collapseReassociation] :
176 llvm::zip_equal(expandReInds, collapseReInds)) {
177 if (collapseReassociation.size() == expandReassociation.size()) {
183 ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
184 collapseReassociation.front(), collapseReassociation.size());
185 int64_t numCollapsedDynamic =
186 llvm::count_if(collapsedStaticShapes, ShapedType::isDynamic);
187 ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
188 expandReassociation.front(), expandReassociation.size());
189 int64_t numExpandedDynamic =
190 llvm::count_if(expandedStaticShapes, ShapedType::isDynamic);
191 if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
192 collapsedStaticShapes != expandedStaticShapes) {
// Mixed-rank groups: at least one side must be a singleton group, i.e. the
// group is touched by only one of the two reshapes ("parallel").
199 if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
// Build the swapped reassociations plus the output sizes of the new expand.
204 SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
205 Location loc = expandOp->getLoc();
206 SmallVector<OpFoldResult> sourceSizes =
208 SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
209 SmallVector<OpFoldResult> newExpandSizes;
// Running indices into the new expand/collapse dim spaces and the
// source/result size vectors (declaration appears truncated).
211 int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
214 for (
size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
215 auto &collapseReassociation = collapseReInds[idx];
216 auto &expandReassociation = expandReInds[idx];
// Same-rank group: both reshapes act as identity here; emit singleton
// groups on both sides, sized from the original expand's result.
224 if (collapseReassociation.size() == expandReassociation.size()) {
225 for (
size_t i = 0; i < collapseReassociation.size(); ++i) {
226 newCollapseReInds.push_back({newCollapseIndex++});
227 newExpandReInds.push_back({newExpandIndex++});
228 newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
// Collapse-only group: keep the collapse grouping; the new expand is
// identity over these dims, sized from the pre-collapse source.
237 if (collapseReassociation.size() != 1) {
239 for (
size_t i = 0; i < collapseReassociation.size(); ++i) {
240 newCollapseReassociation.push_back(newCollapseIndex++);
241 newExpandReInds.push_back({newExpandIndex++});
242 newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
245 newCollapseReInds.push_back(newCollapseReassociation);
// Expand-only group: keep the expand grouping; the new collapse is
// identity over these dims.
253 for (
size_t i = 0; i < expandReassociation.size(); ++i) {
254 newExpandReassociation.push_back(newExpandIndex++);
255 newCollapseReInds.push_back({newCollapseIndex++});
256 newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
258 newExpandReInds.push_back(newExpandReassociation);
// Materialize: expand the pre-collapse tensor first, then collapse. The
// static/dynamic size split for the new type appears elided
// (dispatchIndexOpFoldResults, presumably).
263 SmallVector<Value> dynamicSizes;
264 SmallVector<int64_t> staticSizes;
266 auto expandResultType = expandOp.getResultType().clone(staticSizes);
267 Value newCollapseSrc = collapseOp.getSrc();
// Size mismatch between groups and dims indicates a non-identity reshape;
// identity reshapes are presumably skipped by these guards.
271 if (newExpandReInds.size() != newExpandSizes.size()) {
272 newCollapseSrc = tensor::ExpandShapeOp::create(
273 rewriter, loc, expandResultType, newCollapseSrc, newExpandReInds,
281 if (newCollapseReInds.size() != newExpandSizes.size()) {
283 rewriter, loc, newCollapseSrc, newCollapseReInds);
// Pattern: bubble a tensor.extract_slice up through its producing
// tensor.expand_shape — slice the smaller pre-expansion tensor, then
// re-expand the result. NOTE(review): extraction dropped interior lines.
321struct BubbleUpExtractSliceThroughExpandShape
323 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
325 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
326 PatternRewriter &rewriter)
const override {
328 sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
329 if (!expandShapeOp) {
331 sliceOp,
"tensor.extract_slice source not produced by expand_shape");
333 SmallVector<ReassociationIndices> reassociation =
334 expandShapeOp.getReassociationIndices();
// Compute collapsed offsets/sizes/strides for the pre-expansion slice (the
// call — getCollapsedExtractSliceInfo, per the declaration below — was
// truncated by extraction; failure bails out).
336 SmallVector<OpFoldResult> offsets, sizes, strides;
338 offsets, sizes, strides)))
// Keep the original (expanded) sizes for rebuilding the expand_shape.
342 SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
343 RankedTensorType resultType = sliceOp.getResultType();
// Slice the expand's source, then re-expand to the original result type.
346 Location loc = sliceOp.getLoc();
347 Value newSliceOp = tensor::ExtractSliceOp::create(
348 rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides);
350 sliceOp, resultType, newSliceOp,
351 expandShapeOp.getReassociationIndices(), expandedSizes);
// Pattern: bubble a tensor.extract_slice up through its producing
// tensor.collapse_shape — slice the pre-collapse tensor, then re-collapse
// the smaller result. NOTE(review): extraction dropped interior lines.
429struct BubbleUpExtractSliceThroughCollapseShape
431 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
433 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
434 PatternRewriter &rewriter)
const override {
435 auto collapseShapeOp =
436 sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
437 if (!collapseShapeOp) {
440 "tensor.extract_slice source not produced by tensor.collapse_shape");
// Expand the collapsed slice parameters onto the pre-collapse shape (call
// to getExpandedExtractSliceInfo appears truncated by extraction).
443 SmallVector<OpFoldResult> offsets, sizes, strides;
445 rewriter, sliceOp, collapseShapeOp.getReassociationIndices(),
446 collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides)))
// Slice the collapse's source, then re-collapse to the original type.
449 Value newSliceOp = tensor::ExtractSliceOp::create(
450 rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), offsets,
453 sliceOp, sliceOp.getResultType(), newSliceOp,
454 collapseShapeOp.getReassociationIndices());
// NOTE(review): extraction dropped the enclosing signature; per the
// declaration summarized below, this is the body of
// getCollapsedExtractSliceInfo, which computes the offsets/sizes/strides
// needed to build a collapsed sliceOp.
// Only unit-stride slices are supported.
468 if (!sliceOp.hasUnitStride()) {
// The slice must not be rank-reducing: one size per result dimension.
475 if (
static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) {
// Equality query via a value-bounds analysis (ValueBoundsConstraintSet,
// per the summaries below); only a definitive success counts as equal.
484 FailureOr<bool> maybeEqual =
486 return llvm::succeeded(maybeEqual) && maybeEqual.value();
// A group is collapsible only if each expanded dim in it is taken whole,
// from offset zero (surrounding condition elided by extraction).
516 if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
// Per reassociation group: gather per-dim shapes/offsets and multiply the
// per-dim sizes into a single collapsed size.
556 for (
long expandedDim :
indices) {
560 reassocGroupSizes.push_back(expandedShape[expandedDim]);
561 reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
562 collapsedSize =
mul(collapsedSize, expandedSizes[expandedDim]);
// Materialize the group's offsets as index values and linearize them into
// one collapsed offset with affine.linearize_index.
566 llvm::map_to_vector(reassocGroupOffsets, [&](
OpFoldResult ofr) {
569 OpFoldResult collapsedOffset = affine::AffineLinearizeIndexOp::create(
570 b, loc, offsetVals, reassocGroupSizes,
573 collapsedOffsets.push_back(collapsedOffset);
574 collapsedSizes.push_back(collapsedSize);
// Strides are all ones, by the unit-stride precondition above.
577 collapsedStrides.push_back(
b.getIndexAttr(1));
// NOTE(review): extraction dropped the enclosing signature; per the
// declaration summarized below, this is the body of
// getExpandedExtractSliceInfo, which computes the offsets/sizes/strides
// needed to build an expanded sliceOp.
// Only unit-stride slices are supported.
589 if (!sliceOp.hasUnitStride()) {
// The slice must not be rank-reducing: one collapsed size per result dim.
600 if (
static_cast<size_t>(sliceOp.getResultType().getRank()) !=
601 collapsedSizes.size()) {
// All expanded strides are 1 (unit-stride precondition).
609 expandedStrides.resize(expandedShape.size(),
b.getIndexAttr(1));
// Map each collapsed (size, offset) pair onto its reassociation group.
610 for (
auto [collapsedSize, collapsedOffset, reassocIndices] :
611 llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) {
// Dynamic size/offset case: supported only when the group has exactly one
// non-unit dim — the dynamic values attach to that dim; unit dims get
// size 1 / offset 0.
616 if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
617 int nonUnitSizeCount = 0;
618 for (
int64_t expandedShapeIdx : reassocIndices) {
619 if (expandedShape[expandedShapeIdx] != 1) {
621 expandedSizes.push_back(collapsedSize);
622 expandedOffsets.push_back(collapsedOffset);
626 expandedSizes.push_back(
b.getIndexAttr(1));
627 expandedOffsets.push_back(
b.getIndexAttr(0));
630 if (nonUnitSizeCount != 1) {
// Static case: delinearize the collapsed size/offset over the group's
// dims, walking from innermost (rightmost) to outermost via the reversed
// reassociation indices.
653 int64_t currentCollapsedOffset =
657 reassocIndices.rend());
659 int64_t reassocGroupSize = reassocIndices.size();
// Consume fully-covered inner dims: both the remaining size and offset
// must be exact multiples of the dim extent.
663 for (; idx < reassocGroupSize; ++idx) {
664 int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
666 if (currentCollapsedsize < expandedShapeSize)
671 if ((currentCollapsedsize % expandedShapeSize) != 0 ||
672 (currentCollapsedOffset % expandedShapeSize) != 0) {
676 groupExpandedSizes.push_back(
b.getIndexAttr(expandedShapeSize));
677 groupExpandedOffsets.push_back(
b.getIndexAttr(0));
679 currentCollapsedsize /= expandedShapeSize;
680 currentCollapsedOffset /= expandedShapeSize;
// First partially-covered dim: the residual slice must stay inside this
// dim (the >= check presumably rejects straddling slices — elided).
684 if (idx < reassocGroupSize) {
685 int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
686 int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
689 if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
692 groupExpandedSizes.push_back(
b.getIndexAttr(currentCollapsedsize));
693 groupExpandedOffsets.push_back(
b.getIndexAttr(offsetInDim));
694 currentCollapsedOffset /= expandedShapeSize;
// Remaining outer dims contribute size 1 at the residual offset.
704 for (idx++; idx < reassocGroupSize; ++idx) {
705 int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
706 int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
707 groupExpandedSizes.push_back(
b.getIndexAttr(1));
708 groupExpandedOffsets.push_back(
b.getIndexAttr(offsetInDim));
709 currentCollapsedOffset /= expandedShapeSize;
// Results were built innermost-first; append reversed so the final order
// is outermost-first.
711 expandedSizes.append(groupExpandedSizes.rbegin(),
712 groupExpandedSizes.rend());
713 expandedOffsets.append(groupExpandedOffsets.rbegin(),
714 groupExpandedOffsets.rend());
// Registration of the reassociative-reshape folding patterns — per the
// declarations below, the body of populateReassociativeReshapeFoldingPatterns
// (function header elided by extraction).
722 .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
723 FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
724 FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
725 FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
726 FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
// Registration of the bubble-up extract_slice patterns — per the
// declarations below, the body of populateBubbleUpExtractSliceOpPatterns.
737 patterns.add<BubbleUpExtractSliceThroughExpandShape,
738 BubbleUpExtractSliceThroughCollapseShape>(
patterns.getContext());
Reports whether copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies, respectively, should be inserted. The output argument `nBegin` is set to the replacement of `begin` (set to `begin` itself if no invalidation happens), since outgoing copies could have been inserted at `end`.
Base type for affine expression.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
A variable that can be added to the constraint set as a "column".
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
LogicalResult getCollapsedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp, ArrayRef< ReassociationIndices > reassociation, SmallVectorImpl< OpFoldResult > &collapsedOffsets, SmallVectorImpl< OpFoldResult > &collapsedSizes, SmallVectorImpl< OpFoldResult > &collapsedStrides)
Computes the offsets, sizes, and strides needed to build a collapsed sliceOp.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other o...
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
LogicalResult getExpandedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp, ArrayRef< ReassociationIndices > reassociation, ArrayRef< int64_t > expandedShape, SmallVectorImpl< OpFoldResult > &expandedOffsets, SmallVectorImpl< OpFoldResult > &expandedSizes, SmallVectorImpl< OpFoldResult > &expandedStrides)
Computes the offsets, sizes, and strides needed to build an expanded sliceOp.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
SmallVector< int64_t, 2 > ReassociationIndices
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...