15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/Support/Debug.h"
17 #include "llvm/Support/LogicalResult.h"
24 struct FoldExpandOfRankReducingExtract
28 LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
30 RankedTensorType resultType = expandShapeOp.getResultType();
32 expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
35 RankedTensorType srcType = extractSliceOp.getSourceType();
40 RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
41 srcType, extractSliceOp.getStaticOffsets(),
42 extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
43 if (nonReducingExtractType != resultType)
50 expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
58 struct FoldUnPaddingCollapseIntoExtract
62 LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
65 collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
69 if (!extractSliceOp || !extractSliceOp->hasOneUse())
75 collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
78 "expected unpadding collapse");
80 Value unPaddedExtractSlice = rewriter.
create<tensor::ExtractSliceOp>(
81 extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
82 extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
83 extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
84 rewriter.
replaceOp(collapseShapeOp, unPaddedExtractSlice);
90 template <
typename OpTy>
94 LogicalResult matchAndRewrite(OpTy insertSliceOp,
96 auto collapseShapeOp =
97 insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
100 RankedTensorType srcType = collapseShapeOp.getSrcType();
105 RankedTensorType nonReducingInsertType =
107 insertSliceOp.getDestType().getElementType());
108 if (nonReducingInsertType != srcType)
115 insertSliceOp.getDest(), mixedOffsets,
116 mixedSizes, mixedStrides);
123 template <
typename OpTy>
127 LogicalResult matchAndRewrite(OpTy insertSliceOp,
129 auto expandShapeOp = insertSliceOp.getSource()
130 .template getDefiningOp<tensor::ExpandShapeOp>();
137 expandShapeOp.getResultType(), expandShapeOp.getSrcType());
140 "expected rank increasing expansion");
143 insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
151 struct BubbleUpExpandThroughParallelCollapse
155 LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
158 expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
161 auto expandReInds = expandOp.getReassociationIndices();
162 auto collapseReInds = collapseOp.getReassociationIndices();
166 if (expandReInds.size() == 0) {
177 for (
auto [expandReassociation, collapseReassociation] :
178 llvm::zip_equal(expandReInds, collapseReInds)) {
179 if (collapseReassociation.size() == expandReassociation.size()) {
186 collapseReassociation.front(), collapseReassociation.size());
187 int64_t numCollapsedDynamic =
188 llvm::count_if(collapsedStaticShapes,
189 [](int64_t d) {
return ShapedType::isDynamic(d); });
191 expandReassociation.front(), expandReassociation.size());
192 int64_t numExpandedDynamic =
193 llvm::count_if(expandedStaticShapes,
194 [](int64_t d) {
return ShapedType::isDynamic(d); });
195 if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
196 collapsedStaticShapes != expandedStaticShapes) {
203 if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
215 int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
218 for (
size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
219 auto &collapseReassociation = collapseReInds[idx];
220 auto &expandReassociation = expandReInds[idx];
228 if (collapseReassociation.size() == expandReassociation.size()) {
229 for (
size_t i = 0; i < collapseReassociation.size(); ++i) {
230 newCollapseReInds.push_back({newCollapseIndex++});
231 newExpandReInds.push_back({newExpandIndex++});
232 newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
241 if (collapseReassociation.size() != 1) {
243 for (
size_t i = 0; i < collapseReassociation.size(); ++i) {
244 newCollapseReassociation.push_back(newCollapseIndex++);
245 newExpandReInds.push_back({newExpandIndex++});
246 newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
249 newCollapseReInds.push_back(newCollapseReassociation);
257 for (
size_t i = 0; i < expandReassociation.size(); ++i) {
258 newExpandReassociation.push_back(newExpandIndex++);
259 newCollapseReInds.push_back({newCollapseIndex++});
260 newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
262 newExpandReInds.push_back(newExpandReassociation);
270 auto expandResultType = expandOp.getResultType().clone(staticSizes);
271 Value newCollapseSrc = collapseOp.getSrc();
275 if (newExpandReInds.size() != newExpandSizes.size()) {
276 newCollapseSrc = rewriter.
create<tensor::ExpandShapeOp>(
277 loc, expandResultType, newCollapseSrc, newExpandReInds,
284 Value replacement = newCollapseSrc;
285 if (newCollapseReInds.size() != newExpandSizes.size()) {
286 replacement = rewriter.
create<tensor::CollapseShapeOp>(
287 loc, newCollapseSrc, newCollapseReInds);
289 rewriter.
replaceOp(expandOp, replacement);
325 struct BubbleUpExpandShapeThroughExtractSlice
329 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
332 sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
334 if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
348 expandShapeOp.getOutputShape(), rewriter);
351 Location loc = expandShapeOp->getLoc();
368 expandShapeOp.getReassociationIndices()) {
377 for (
long expandedDim : indices) {
381 reassocGroupSizes.push_back(expandedShape[expandedDim]);
382 reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
383 collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
387 llvm::map_to_vector(reassocGroupOffsets, [&](
OpFoldResult ofr) {
392 .
create<affine::AffineLinearizeIndexOp>(loc, offsetVals,
396 collapsedOffsets.push_back(collapsedOffset);
397 collapsedSizes.push_back(collapsedSize);
408 shape, expandShapeOp.getResultType().getElementType());
411 Value newSliceOp = rewriter.
create<tensor::ExtractSliceOp>(
412 loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
415 sliceOp, resultType, newSliceOp,
416 expandShapeOp.getReassociationIndices(), expandedSizes);
424 checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
425 tensor::ExpandShapeOp expandShapeOp,
428 if (!expandShapeOp) {
430 sliceOp,
"tensor.extract_slice source not produced by expand_shape");
433 if (!sliceOp.hasUnitStride()) {
435 sliceOp,
"unsupported: non-unit stride. Only contiguous slices can "
436 "be supported in this transformation.");
442 if (
static_cast<size_t>(sliceOp.getResultType().getRank()) !=
445 "unimplemented: rank reducing slice");
450 expandShapeOp.getOutputShape(), rewriter);
453 isZeroOffsetAndFullSize =
457 FailureOr<bool> maybeEqual =
459 return llvm::succeeded(maybeEqual) && maybeEqual.value();
473 expandShapeOp.getReassociationIndices()) {
475 int64_t e = indices.size();
489 int64_t expandedDim = indices[i];
490 if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
491 outputShape[expandedDim])) {
493 sliceOp,
"Not a contiguous slice of the expanded tensor.");
575 struct BubbleUpCollapseShapeThroughExtractSlice
579 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
581 auto collapseShapeOp =
582 sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
583 if (!collapseShapeOp) {
586 "tensor.extract_slice source not produced by tensor.collapse_shape");
589 if (!sliceOp.hasUnitStride()) {
591 sliceOp,
"unsupported: non-unit stride. Only contiguous slices can "
592 "be supported in this transformation.");
603 if (
static_cast<size_t>(sliceOp.getResultType().getRank()) !=
604 collapsedSizes.size()) {
606 "unimplemented: rank reducing slice");
611 collapseShapeOp.getReassociationIndices();
621 for (
auto [collapsedSize, collapsedOffset, reassocIndices] :
622 llvm::zip_equal(collapsedSizes, collapsedOffsets,
623 collapseShapeOp.getReassociationIndices())) {
628 if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
629 int nonUnitSizeCount = 0;
630 for (int64_t expandedShapeIdx : reassocIndices) {
631 if (srcShape[expandedShapeIdx] != 1) {
633 expandedSizes.push_back(collapsedSize);
634 expandedOffsets.push_back(collapsedOffset);
642 if (nonUnitSizeCount != 1) {
645 "unsupported: slice cannot be verified to be contiguous");
667 int64_t currentCollapsedOffset =
673 reassocIndices.rend());
675 int64_t reassocGroupSize = reassocIndices.size();
679 for (; idx < reassocGroupSize; ++idx) {
680 int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
682 if (currentCollapsedsize < expandedShapeSize)
687 if ((currentCollapsedsize % expandedShapeSize) != 0 ||
688 (currentCollapsedOffset % expandedShapeSize) != 0) {
690 sliceOp,
"unsupported: cannot be extracted as a contiguous slice "
691 "of the src of the collapse_shape");
694 groupExpandedSizes.push_back(rewriter.
getIndexAttr(expandedShapeSize));
695 groupExpandedOffsets.push_back(rewriter.
getIndexAttr(0));
697 currentCollapsedsize /= expandedShapeSize;
698 currentCollapsedOffset /= expandedShapeSize;
702 if (idx < reassocGroupSize) {
703 int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
704 int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
707 if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
709 sliceOp,
"unsupported: slice cannot be extracted as a contiguous "
710 "slice of the src of the collapse_shape");
713 groupExpandedSizes.push_back(
715 groupExpandedOffsets.push_back(rewriter.
getIndexAttr(offsetInDim));
717 currentCollapsedOffset /= expandedShapeSize;
727 for (idx++; idx < reassocGroupSize; ++idx) {
728 int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
729 int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
731 groupExpandedOffsets.push_back(rewriter.
getIndexAttr(offsetInDim));
732 currentCollapsedOffset /= expandedShapeSize;
735 expandedSizes.append(groupExpandedSizes.rbegin(),
736 groupExpandedSizes.rend());
737 expandedOffsets.append(groupExpandedOffsets.rbegin(),
738 groupExpandedOffsets.rend());
741 Value newSliceOp = rewriter.
create<tensor::ExtractSliceOp>(
742 collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets,
743 expandedSizes, expandedStrides);
745 sliceOp, sliceOp.getResultType(), newSliceOp,
746 collapseShapeOp.getReassociationIndices());
757 .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
758 FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
759 FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
760 FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
761 FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
772 patterns.add<BubbleUpExpandShapeThroughExtractSlice,
773 BubbleUpCollapseShapeThroughExtractSlice>(
patterns.getContext());
Base type for affine expression.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around an Attribute.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched and rewritten.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replaces the original op).
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any other AffineApplyOp supplying those operands.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other ops.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor.collapse_shape ops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of ops.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which ShapedType::isDynamic is true are replaced by dynamicValues.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions with static size 1.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.