15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/Support/LogicalResult.h"
23 struct FoldExpandOfRankReducingExtract
27 LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
29 RankedTensorType resultType = expandShapeOp.getResultType();
31 expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
34 RankedTensorType srcType = extractSliceOp.getSourceType();
39 RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
40 srcType, extractSliceOp.getStaticOffsets(),
41 extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
42 if (nonReducingExtractType != resultType)
49 expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
57 struct FoldUnPaddingCollapseIntoExtract
61 LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
64 collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
68 if (!extractSliceOp || !extractSliceOp->hasOneUse())
74 collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
77 "expected unpadding collapse");
79 Value unPaddedExtractSlice = tensor::ExtractSliceOp::create(
80 rewriter, extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
81 extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
82 extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
83 rewriter.
replaceOp(collapseShapeOp, unPaddedExtractSlice);
89 template <
typename OpTy>
93 LogicalResult matchAndRewrite(OpTy insertSliceOp,
95 auto collapseShapeOp =
96 insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
99 RankedTensorType srcType = collapseShapeOp.getSrcType();
104 RankedTensorType nonReducingInsertType =
106 insertSliceOp.getDestType().getElementType());
107 if (nonReducingInsertType != srcType)
114 insertSliceOp.getDest(), mixedOffsets,
115 mixedSizes, mixedStrides);
122 template <
typename OpTy>
126 LogicalResult matchAndRewrite(OpTy insertSliceOp,
128 auto expandShapeOp = insertSliceOp.getSource()
129 .template getDefiningOp<tensor::ExpandShapeOp>();
136 expandShapeOp.getResultType(), expandShapeOp.getSrcType());
139 "expected rank increasing expansion");
142 insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
150 struct BubbleUpExpandThroughParallelCollapse
154 LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
157 expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
160 auto expandReInds = expandOp.getReassociationIndices();
161 auto collapseReInds = collapseOp.getReassociationIndices();
165 if (expandReInds.size() == 0) {
176 for (
auto [expandReassociation, collapseReassociation] :
177 llvm::zip_equal(expandReInds, collapseReInds)) {
178 if (collapseReassociation.size() == expandReassociation.size()) {
185 collapseReassociation.front(), collapseReassociation.size());
186 int64_t numCollapsedDynamic =
187 llvm::count_if(collapsedStaticShapes, ShapedType::isDynamic);
189 expandReassociation.front(), expandReassociation.size());
190 int64_t numExpandedDynamic =
191 llvm::count_if(expandedStaticShapes, ShapedType::isDynamic);
192 if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
193 collapsedStaticShapes != expandedStaticShapes) {
200 if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
212 int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
215 for (
size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
216 auto &collapseReassociation = collapseReInds[idx];
217 auto &expandReassociation = expandReInds[idx];
225 if (collapseReassociation.size() == expandReassociation.size()) {
226 for (
size_t i = 0; i < collapseReassociation.size(); ++i) {
227 newCollapseReInds.push_back({newCollapseIndex++});
228 newExpandReInds.push_back({newExpandIndex++});
229 newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
238 if (collapseReassociation.size() != 1) {
240 for (
size_t i = 0; i < collapseReassociation.size(); ++i) {
241 newCollapseReassociation.push_back(newCollapseIndex++);
242 newExpandReInds.push_back({newExpandIndex++});
243 newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
246 newCollapseReInds.push_back(newCollapseReassociation);
254 for (
size_t i = 0; i < expandReassociation.size(); ++i) {
255 newExpandReassociation.push_back(newExpandIndex++);
256 newCollapseReInds.push_back({newCollapseIndex++});
257 newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
259 newExpandReInds.push_back(newExpandReassociation);
267 auto expandResultType = expandOp.getResultType().clone(staticSizes);
268 Value newCollapseSrc = collapseOp.getSrc();
272 if (newExpandReInds.size() != newExpandSizes.size()) {
273 newCollapseSrc = tensor::ExpandShapeOp::create(
274 rewriter, loc, expandResultType, newCollapseSrc, newExpandReInds,
281 Value replacement = newCollapseSrc;
282 if (newCollapseReInds.size() != newExpandSizes.size()) {
283 replacement = tensor::CollapseShapeOp::create(
284 rewriter, loc, newCollapseSrc, newCollapseReInds);
286 rewriter.
replaceOp(expandOp, replacement);
322 struct BubbleUpExpandShapeThroughExtractSlice
326 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
329 sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
330 if (!expandShapeOp) {
332 sliceOp,
"tensor.extract_slice source not produced by expand_shape");
335 expandShapeOp.getReassociationIndices();
339 offsets, sizes, strides)))
344 RankedTensorType resultType = sliceOp.getResultType();
348 Value newSliceOp = tensor::ExtractSliceOp::create(
349 rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides);
351 sliceOp, resultType, newSliceOp,
352 expandShapeOp.getReassociationIndices(), expandedSizes);
430 struct BubbleUpCollapseShapeThroughExtractSlice
434 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
436 auto collapseShapeOp =
437 sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
438 if (!collapseShapeOp) {
441 "tensor.extract_slice source not produced by tensor.collapse_shape");
446 rewriter, sliceOp, collapseShapeOp.getReassociationIndices(),
447 collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides)))
450 Value newSliceOp = tensor::ExtractSliceOp::create(
451 rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), offsets,
454 sliceOp, sliceOp.getResultType(), newSliceOp,
455 collapseShapeOp.getReassociationIndices());
464 OpBuilder &b, tensor::ExtractSliceOp sliceOp,
469 if (!sliceOp.hasUnitStride()) {
476 if (
static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) {
485 FailureOr<bool> maybeEqual =
487 return llvm::succeeded(maybeEqual) && maybeEqual.value();
502 int64_t e = indices.size();
516 int64_t expandedDim = indices[i];
517 if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
557 for (
long expandedDim : indices) {
561 reassocGroupSizes.push_back(expandedShape[expandedDim]);
562 reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
563 collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
567 llvm::map_to_vector(reassocGroupOffsets, [&](
OpFoldResult ofr) {
570 OpFoldResult collapsedOffset = affine::AffineLinearizeIndexOp::create(
571 b, loc, offsetVals, reassocGroupSizes,
574 collapsedOffsets.push_back(collapsedOffset);
575 collapsedSizes.push_back(collapsedSize);
584 OpBuilder &b, tensor::ExtractSliceOp sliceOp,
590 if (!sliceOp.hasUnitStride()) {
601 if (
static_cast<size_t>(sliceOp.getResultType().getRank()) !=
602 collapsedSizes.size()) {
610 expandedStrides.resize(expandedShape.size(), b.
getIndexAttr(1));
611 for (
auto [collapsedSize, collapsedOffset, reassocIndices] :
612 llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) {
617 if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
618 int nonUnitSizeCount = 0;
619 for (int64_t expandedShapeIdx : reassocIndices) {
620 if (expandedShape[expandedShapeIdx] != 1) {
622 expandedSizes.push_back(collapsedSize);
623 expandedOffsets.push_back(collapsedOffset);
631 if (nonUnitSizeCount != 1) {
654 int64_t currentCollapsedOffset =
658 reassocIndices.rend());
660 int64_t reassocGroupSize = reassocIndices.size();
664 for (; idx < reassocGroupSize; ++idx) {
665 int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
667 if (currentCollapsedsize < expandedShapeSize)
672 if ((currentCollapsedsize % expandedShapeSize) != 0 ||
673 (currentCollapsedOffset % expandedShapeSize) != 0) {
677 groupExpandedSizes.push_back(b.
getIndexAttr(expandedShapeSize));
680 currentCollapsedsize /= expandedShapeSize;
681 currentCollapsedOffset /= expandedShapeSize;
685 if (idx < reassocGroupSize) {
686 int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
687 int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
690 if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
693 groupExpandedSizes.push_back(b.
getIndexAttr(currentCollapsedsize));
694 groupExpandedOffsets.push_back(b.
getIndexAttr(offsetInDim));
695 currentCollapsedOffset /= expandedShapeSize;
705 for (idx++; idx < reassocGroupSize; ++idx) {
706 int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
707 int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
709 groupExpandedOffsets.push_back(b.
getIndexAttr(offsetInDim));
710 currentCollapsedOffset /= expandedShapeSize;
712 expandedSizes.append(groupExpandedSizes.rbegin(),
713 groupExpandedSizes.rend());
714 expandedOffsets.append(groupExpandedOffsets.rbegin(),
715 groupExpandedOffsets.rend());
723 .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
724 FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
725 FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
726 FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
727 FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
738 patterns.add<BubbleUpExpandShapeThroughExtractSlice,
739 BubbleUpCollapseShapeThroughExtractSlice>(
patterns.getContext());
Base type for affine expression.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a pointer to a storage object.
This class helps build Operations.
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
A variable that can be added to the constraint set as a "column".
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
LogicalResult getCollapsedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp, ArrayRef< ReassociationIndices > reassociation, SmallVectorImpl< OpFoldResult > &collapsedOffsets, SmallVectorImpl< OpFoldResult > &collapsedSizes, SmallVectorImpl< OpFoldResult > &collapsedStrides)
Computes the offsets, sizes, and strides needed to build a collapsed sliceOp.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other ops.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract_slice ops above their producer.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor.collapse_shape ops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
LogicalResult getExpandedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp, ArrayRef< ReassociationIndices > reassociation, ArrayRef< int64_t > expandedShape, SmallVectorImpl< OpFoldResult > &expandedOffsets, SmallVectorImpl< OpFoldResult > &expandedSizes, SmallVectorImpl< OpFoldResult > &expandedStrides)
Computes the offsets, sizes, and strides needed to build an expanded sliceOp.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of ops.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions with static size 1.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.