15#include "llvm/ADT/STLExtras.h"
16#include "llvm/Support/LogicalResult.h"
// Pattern: fold a tensor.expand_shape whose source is a rank-reducing
// tensor.extract_slice into a single non-rank-reducing extract_slice.
// NOTE(review): this rendering of the file elides several original lines
// (24, 26, 30, 32-33, 35-38, 43-44, 48, 50+); comments below describe only
// the statements visible here.
23struct FoldExpandOfRankReducingExtract
25 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
27 LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
28 PatternRewriter &rewriter)
const override {
29 RankedTensorType resultType = expandShapeOp.getResultType();
// Source of the expand_shape is expected to be an extract_slice (the
// `auto extractSliceOp =` left-hand side is on an elided line).
31 expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
34 RankedTensorType srcType = extractSliceOp.getSourceType();
// Infer the type a NON-rank-reducing slice with identical static
// offsets/sizes/strides would produce; the fold only applies when that
// equals the expand_shape result type.
39 RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
40 srcType, extractSliceOp.getStaticOffsets(),
41 extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
42 if (nonReducingExtractType != resultType)
// Collect the slice parameters in mixed (static-or-dynamic) form for the
// replacement op.
45 SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
46 SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
47 SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
// These operands feed the replacement-op call whose opening line (48) is
// elided; expandShapeOp is replaced by the new non-reducing slice.
49 expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
// Pattern: fold a tensor.collapse_shape that only drops unit ("padding")
// dimensions back into its producing tensor.extract_slice by giving the
// slice the collapsed result type directly.
// NOTE(review): several original lines are elided in this view (58, 60, 63,
// 65-67, 69-73, 76, 78, 84+); comments describe only the visible code.
57struct FoldUnPaddingCollapseIntoExtract
59 using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;
61 LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
62 PatternRewriter &rewriter)
const override {
// The collapse source must come from an extract_slice (lhs elided).
64 collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
// Bail out when there is no producing extract_slice, or when the slice
// has other users (folding would then duplicate work for them).
68 if (!extractSliceOp || !extractSliceOp->hasOneUse())
// Verify the collapse is a pure rank-reduction of the slice type; `res`
// is assigned on an elided line from what appears to be an
// isRankReducedType-style check — confirm against the full source.
74 collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
75 if (res != SliceVerificationResult::Success)
77 "expected unpadding collapse");
// Re-create the extract_slice with the collapsed (unpadded) result type
// but the original offsets/sizes/strides.
79 Value unPaddedExtractSlice = tensor::ExtractSliceOp::create(
80 rewriter, extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
81 extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
82 extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
83 rewriter.
replaceOp(collapseShapeOp, unPaddedExtractSlice);
// Templated insert-slice pattern over OpTy. The struct name is on an elided
// line; per the pattern registration further down in this file
// (FoldInsertOfRankReducingInsert<tensor::InsertSliceOp> and
// <tensor::ParallelInsertSliceOp>), this is presumably that pattern: it
// folds a rank-reducing insert whose source is a collapse_shape.
// NOTE(review): original lines 90, 92, 97-98, 100-103, 108-109, 113, 116+
// are elided; comments describe only the visible statements.
89template <
typename OpTy>
91 using OpRewritePattern<OpTy>::OpRewritePattern;
93 LogicalResult matchAndRewrite(OpTy insertSliceOp,
94 PatternRewriter &rewriter)
const override {
// The inserted value must come from a collapse_shape. `template` keyword
// is required because OpTy is a dependent type.
95 auto collapseShapeOp =
96 insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
99 RankedTensorType srcType = collapseShapeOp.getSrcType();
// Build the type a NON-rank-reducing insert would have (static sizes of
// the insert, element type of the destination); fold only when it matches
// the collapse_shape's source type.
104 RankedTensorType nonReducingInsertType =
105 RankedTensorType::get(insertSliceOp.getStaticSizes(),
106 insertSliceOp.getDestType().getElementType());
107 if (nonReducingInsertType != srcType)
// Mixed-form slice parameters for the replacement op (the op-creation
// call's opening line, 113, is elided).
110 SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
111 SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
112 SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
114 insertSliceOp.getDest(), mixedOffsets,
115 mixedSizes, mixedStrides);
// Templated insert-slice pattern over OpTy. The struct name is on an elided
// line; per the registration below (FoldPaddingExpandIntoInsert for
// tensor::InsertSliceOp and tensor::ParallelInsertSliceOp), this is
// presumably that pattern: it folds a unit-dim-adding ("padding")
// expand_shape into the insert by inserting the unexpanded source directly.
// NOTE(review): original lines 123, 125, 130-135, 138, 140-141, 143+ are
// elided; comments describe only the visible statements.
122template <
typename OpTy>
124 using OpRewritePattern<OpTy>::OpRewritePattern;
126 LogicalResult matchAndRewrite(OpTy insertSliceOp,
127 PatternRewriter &rewriter)
const override {
// The inserted value must come from an expand_shape.
128 auto expandShapeOp = insertSliceOp.getSource()
129 .template getDefiningOp<tensor::ExpandShapeOp>();
// Check the expansion is a pure rank increase; `res` is assigned on an
// elided line — confirm the exact check against the full source.
136 expandShapeOp.getResultType(), expandShapeOp.getSrcType());
137 if (res != SliceVerificationResult::Success)
139 "expected rank increasing expansion");
// In-place update: bypass the expand_shape by wiring the insert's source
// operand to the expand_shape's own source.
142 insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
// Pattern: swap a tensor.expand_shape with a producing tensor.collapse_shape
// when their reassociation groups are "parallel" (pairwise compatible), so
// the expand bubbles up above the collapse.
// NOTE(review): this view elides many original lines (151, 153, 156,
// 158-159, 162-164, 166-173, 179-183, 194-199, 201-204, 208, 211, 213-214,
// 218-224, 230-237, 239, 244-245, 247-253, 258, 260-263, 266, 269-271,
// 275-281, 283, 285+). Comments describe only what is visible.
150struct BubbleUpExpandThroughParallelCollapse
152 using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
154 LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
155 PatternRewriter &rewriter)
const override {
// Producer must be a collapse_shape (the `auto collapseOp =` lhs is on an
// elided line).
157 expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
160 auto expandReInds = expandOp.getReassociationIndices();
161 auto collapseReInds = collapseOp.getReassociationIndices();
// Degenerate case: no reassociation groups on the expand.
165 if (expandReInds.size() == 0) {
// --- Compatibility scan over paired reassociation groups. ---
174 ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
175 ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
176 for (
auto [expandReassociation, collapseReassociation] :
177 llvm::zip_equal(expandReInds, collapseReInds)) {
178 if (collapseReassociation.size() == expandReassociation.size()) {
// Same-size groups: require matching static shapes and at most one
// dynamic dim on each side, otherwise the swap is ambiguous.
184 ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
185 collapseReassociation.front(), collapseReassociation.size());
186 int64_t numCollapsedDynamic =
187 llvm::count_if(collapsedStaticShapes, ShapedType::isDynamic);
188 ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
189 expandReassociation.front(), expandReassociation.size());
190 int64_t numExpandedDynamic =
191 llvm::count_if(expandedStaticShapes, ShapedType::isDynamic);
192 if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
193 collapsedStaticShapes != expandedStaticShapes) {
// Otherwise one of the two groups must be a singleton.
200 if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
// --- Build the new (swapped) reassociation maps and output sizes. ---
205 SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
206 Location loc = expandOp->getLoc();
207 SmallVector<OpFoldResult> sourceSizes =
209 SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
210 SmallVector<OpFoldResult> newExpandSizes;
// Running indices into the new expand/collapse dim spaces and into the
// source/result size vectors (the tail of this declaration is elided).
212 int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
215 for (
size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
216 auto &collapseReassociation = collapseReInds[idx];
217 auto &expandReassociation = expandReInds[idx];
// Case 1: equal-size groups become identity groups on both sides.
225 if (collapseReassociation.size() == expandReassociation.size()) {
226 for (
size_t i = 0; i < collapseReassociation.size(); ++i) {
227 newCollapseReInds.push_back({newCollapseIndex++});
228 newExpandReInds.push_back({newExpandIndex++});
229 newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
// Case 2: multi-dim collapse group paired with a singleton expand group:
// keep the collapse group, expand each of its dims individually.
238 if (collapseReassociation.size() != 1) {
240 for (
size_t i = 0; i < collapseReassociation.size(); ++i) {
241 newCollapseReassociation.push_back(newCollapseIndex++);
242 newExpandReInds.push_back({newExpandIndex++});
243 newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
246 newCollapseReInds.push_back(newCollapseReassociation);
// Case 3 (fall-through): multi-dim expand group paired with a singleton
// collapse group: keep the expand group, collapse dims individually.
254 for (
size_t i = 0; i < expandReassociation.size(); ++i) {
255 newExpandReassociation.push_back(newExpandIndex++);
256 newCollapseReInds.push_back({newCollapseIndex++});
257 newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
259 newExpandReInds.push_back(newExpandReassociation);
// --- Materialize the swapped ops. ---
// Split the new expand sizes into static/dynamic parts (the dispatch call
// itself, original line 266, is elided).
264 SmallVector<Value> dynamicSizes;
265 SmallVector<int64_t> staticSizes;
267 auto expandResultType = expandOp.getResultType().clone(staticSizes);
268 Value newCollapseSrc = collapseOp.getSrc();
// Only create the new expand if it is not a no-op (rank actually changes).
272 if (newExpandReInds.size() != newExpandSizes.size()) {
273 newCollapseSrc = tensor::ExpandShapeOp::create(
274 rewriter, loc, expandResultType, newCollapseSrc, newExpandReInds,
// Likewise only create the new collapse if it is not a no-op.
282 if (newCollapseReInds.size() != newExpandSizes.size()) {
284 rewriter, loc, newCollapseSrc, newCollapseReInds);
// Pattern: bubble a tensor.extract_slice above a producing
// tensor.expand_shape: slice the unexpanded source, then re-expand the
// smaller result.
// NOTE(review): original lines 323, 325, 328, 331, 333, 336, 338, 340-342,
// 345-346, 350, 353+ are elided; comments describe only the visible code.
322struct BubbleUpExtractSliceThroughExpandShape
324 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
326 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
327 PatternRewriter &rewriter)
const override {
// The slice source must come from an expand_shape (lhs elided).
329 sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
330 if (!expandShapeOp) {
332 sliceOp,
"tensor.extract_slice source not produced by expand_shape");
334 SmallVector<ReassociationIndices> reassociation =
335 expandShapeOp.getReassociationIndices();
// Compute collapsed-domain offsets/sizes/strides for the new slice; the
// helper call consuming these out-params is on an elided line (338).
337 SmallVector<OpFoldResult> offsets, sizes, strides;
339 offsets, sizes, strides)))
343 SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
344 RankedTensorType resultType = sliceOp.getResultType();
// New slice on the expand_shape's source, followed by a replacement op
// (its creation line, 350, is elided) that re-expands to resultType.
347 Location loc = sliceOp.getLoc();
348 Value newSliceOp = tensor::ExtractSliceOp::create(
349 rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides);
351 sliceOp, resultType, newSliceOp,
352 expandShapeOp.getReassociationIndices(), expandedSizes);
// Pattern: bubble a tensor.extract_slice above a producing
// tensor.collapse_shape: slice the uncollapsed source, then re-collapse.
// NOTE(review): original lines 431, 433, 439-440, 442-443, 445, 448-449,
// 452-453, 456+ are elided; comments describe only the visible code.
430struct BubbleUpExtractSliceThroughCollapseShape
432 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
434 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
435 PatternRewriter &rewriter)
const override {
436 auto collapseShapeOp =
437 sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
438 if (!collapseShapeOp) {
"tensor.extract_slice source not produced by tensor.collapse_shape");
// Expanded-domain offsets/sizes/strides for the new slice; the helper
// call (presumably getExpandedExtractSliceInfo, declared later in this
// file) opens on an elided line (445).
444 SmallVector<OpFoldResult> offsets, sizes, strides;
446 rewriter, sliceOp, collapseShapeOp.getReassociationIndices(),
447 collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides)))
// Slice the uncollapsed source, then replace sliceOp with an op (creation
// line elided) that collapses the new slice back to the original type.
450 Value newSliceOp = tensor::ExtractSliceOp::create(
451 rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), offsets,
454 sliceOp, sliceOp.getResultType(), newSliceOp,
455 collapseShapeOp.getReassociationIndices());
// Fragment: per the declaration summary later in this file, this appears to
// be (part of) getCollapsedExtractSliceInfo, which computes the offsets,
// sizes, and strides needed to build a collapsed extract_slice.
// NOTE(review): the function signature and large stretches of the body
// (original lines 470-475, 477-484, 486, 488-516, 518-556, 558-560,
// 564-566, 568-569, 572-573, 576-577, 579+) are elided. Comments below are
// limited to the visible statements.
// Only unit-stride slices are supported (guard body elided).
469 if (!sliceOp.hasUnitStride()) {
// Rank-reducing slices (result rank != number of sizes) are rejected.
476 if (
static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) {
// Visible tail of a helper/lambda: succeeds only when a value-bounds style
// equality check both succeeded and returned true.
485 FailureOr<bool> maybeEqual =
487 return llvm::succeeded(maybeEqual) && maybeEqual.value();
// A per-dimension check that the slice covers the dim fully from offset 0
// (condition continuation elided).
517 if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
// Per reassociation group: gather the expanded dims' sizes/offsets and
// accumulate the collapsed size as a product.
557 for (
long expandedDim :
indices) {
561 reassocGroupSizes.push_back(expandedShape[expandedDim]);
562 reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
563 collapsedSize =
mul(collapsedSize, expandedSizes[expandedDim]);
// Materialize the group's offsets as values (lambda body elided) and
// linearize them into a single collapsed offset.
567 llvm::map_to_vector(reassocGroupOffsets, [&](
OpFoldResult ofr) {
570 OpFoldResult collapsedOffset = affine::AffineLinearizeIndexOp::create(
571 b, loc, offsetVals, reassocGroupSizes,
574 collapsedOffsets.push_back(collapsedOffset);
575 collapsedSizes.push_back(collapsedSize);
// Collapsed strides are always 1 (unit-stride precondition above).
578 collapsedStrides.push_back(
b.getIndexAttr(1));
// Fragment: per the declaration summary later in this file, this appears to
// be (part of) getExpandedExtractSliceInfo, which computes the offsets,
// sizes, and strides needed to build an expanded extract_slice.
// NOTE(review): the signature and many body lines (original 591-600,
// 603-609, 613-616, 621, 624-626, 629-630, 632-653, 655-657, 659, 661-663,
// 666, 668-671, 674-676, 679, 682-684, 688-689, 691-692, 696-704, 711,
// 716+) are elided; comments cover only what is visible.
// Only unit-stride slices are supported (guard body elided).
590 if (!sliceOp.hasUnitStride()) {
// Rank-reducing slices are rejected.
601 if (
static_cast<size_t>(sliceOp.getResultType().getRank()) !=
602 collapsedSizes.size()) {
// Expanded strides are all 1 (unit-stride precondition).
610 expandedStrides.resize(expandedShape.size(),
b.getIndexAttr(1));
// Walk each collapsed dim together with its reassociation group.
611 for (
auto [collapsedSize, collapsedOffset, reassocIndices] :
612 llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) {
// Dynamic size/offset case: only expressible when the group has exactly
// one non-unit expanded dim — that dim receives the dynamic size/offset,
// all unit dims get size 1 / offset 0.
617 if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
618 int nonUnitSizeCount = 0;
619 for (
int64_t expandedShapeIdx : reassocIndices) {
620 if (expandedShape[expandedShapeIdx] != 1) {
622 expandedSizes.push_back(collapsedSize);
623 expandedOffsets.push_back(collapsedOffset);
627 expandedSizes.push_back(
b.getIndexAttr(1));
628 expandedOffsets.push_back(
b.getIndexAttr(0));
631 if (nonUnitSizeCount != 1) {
// Static case: delinearize the collapsed size/offset over the group,
// iterating the group's dims innermost-first (reversed indices).
654 int64_t currentCollapsedOffset =
658 reassocIndices.rend());
660 int64_t reassocGroupSize = reassocIndices.size();
// Phase 1: dims fully covered by the remaining size — take the whole dim
// (offset 0) and divide size/offset through.
664 for (; idx < reassocGroupSize; ++idx) {
665 int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
667 if (currentCollapsedsize < expandedShapeSize)
// Divisibility is required for an exact decomposition; otherwise bail
// (failure body elided).
672 if ((currentCollapsedsize % expandedShapeSize) != 0 ||
673 (currentCollapsedOffset % expandedShapeSize) != 0) {
677 groupExpandedSizes.push_back(
b.getIndexAttr(expandedShapeSize));
678 groupExpandedOffsets.push_back(
b.getIndexAttr(0));
680 currentCollapsedsize /= expandedShapeSize;
681 currentCollapsedOffset /= expandedShapeSize;
// Phase 2: the first partially-covered dim takes the remaining size at
// the in-dim offset, provided it does not straddle the dim boundary.
685 if (idx < reassocGroupSize) {
686 int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
687 int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
690 if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
693 groupExpandedSizes.push_back(
b.getIndexAttr(currentCollapsedsize));
694 groupExpandedOffsets.push_back(
b.getIndexAttr(offsetInDim));
695 currentCollapsedOffset /= expandedShapeSize;
// Phase 3: remaining outer dims contribute only an offset (size 1).
705 for (idx++; idx < reassocGroupSize; ++idx) {
706 int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
707 int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
708 groupExpandedSizes.push_back(
b.getIndexAttr(1));
709 groupExpandedOffsets.push_back(
b.getIndexAttr(offsetInDim));
710 currentCollapsedOffset /= expandedShapeSize;
// Results were built innermost-first; append in reverse to restore the
// group's original dim order.
712 expandedSizes.append(groupExpandedSizes.rbegin(),
713 groupExpandedSizes.rend());
714 expandedOffsets.append(groupExpandedOffsets.rbegin(),
715 groupExpandedOffsets.rend());
// Fragment of populateReassociativeReshapeFoldingPatterns (per the
// declaration summary later in this file): registers the reshape-folding
// patterns defined above, instantiating the two templated patterns for both
// tensor.insert_slice and tensor.parallel_insert_slice.
723 .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
724 FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
725 FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
726 FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
727 FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
// Fragment of populateBubbleUpExtractSliceOpPatterns (per the declaration
// summary later in this file): registers the two extract_slice bubbling
// patterns defined above.
738 patterns.add<BubbleUpExtractSliceThroughExpandShape,
739 BubbleUpExtractSliceThroughCollapseShape>(
patterns.getContext());
* if copies could not be generated due to yet-unimplemented cases.
* `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock`
* specify the insertion points where the incoming and outgoing copies should
* be placed. The output argument `nBegin` is set to its replacement (set to
* `begin` if no invalidation happens), since outgoing copies could have been
* inserted at `end`.
Base type for affine expression.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
A variable that can be added to the constraint set as a "column".
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
LogicalResult getCollapsedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp, ArrayRef< ReassociationIndices > reassociation, SmallVectorImpl< OpFoldResult > &collapsedOffsets, SmallVectorImpl< OpFoldResult > &collapsedSizes, SmallVectorImpl< OpFoldResult > &collapsedStrides)
Computes the offsets, sizes, and strides needed to build a collapsed sliceOp.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other o...
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract_slice ops above their producers.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
LogicalResult getExpandedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp, ArrayRef< ReassociationIndices > reassociation, ArrayRef< int64_t > expandedShape, SmallVectorImpl< OpFoldResult > &expandedOffsets, SmallVectorImpl< OpFoldResult > &expandedSizes, SmallVectorImpl< OpFoldResult > &expandedStrides)
Computes the offsets, sizes, and strides needed to build an expanded sliceOp.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
SmallVector< int64_t, 2 > ReassociationIndices
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...