15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/Support/Debug.h"
17 #include "llvm/Support/LogicalResult.h"
24 struct FoldExpandOfRankReducingExtract
28 LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
30 RankedTensorType resultType = expandShapeOp.getResultType();
32 expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
35 RankedTensorType srcType = extractSliceOp.getSourceType();
40 RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
41 srcType, extractSliceOp.getStaticOffsets(),
42 extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
43 if (nonReducingExtractType != resultType)
50 expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
58 struct FoldUnPaddingCollapseIntoExtract
62 LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
65 collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
69 if (!extractSliceOp || !extractSliceOp->hasOneUse())
75 collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
78 "expected unpadding collapse");
80 Value unPaddedExtractSlice = rewriter.
create<tensor::ExtractSliceOp>(
81 extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
82 extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
83 extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
84 rewriter.
replaceOp(collapseShapeOp, unPaddedExtractSlice);
90 template <
typename OpTy>
94 LogicalResult matchAndRewrite(OpTy insertSliceOp,
96 auto collapseShapeOp =
97 insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
100 RankedTensorType srcType = collapseShapeOp.getSrcType();
105 RankedTensorType nonReducingInsertType =
107 insertSliceOp.getDestType().getElementType());
108 if (nonReducingInsertType != srcType)
115 insertSliceOp.getDest(), mixedOffsets,
116 mixedSizes, mixedStrides);
123 template <
typename OpTy>
127 LogicalResult matchAndRewrite(OpTy insertSliceOp,
129 auto expandShapeOp = insertSliceOp.getSource()
130 .template getDefiningOp<tensor::ExpandShapeOp>();
137 expandShapeOp.getResultType(), expandShapeOp.getSrcType());
140 "expected rank increasing expansion");
143 insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
151 struct BubbleUpExpandThroughParallelCollapse
155 LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
158 expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
161 auto expandReInds = expandOp.getReassociationIndices();
162 auto collapseReInds = collapseOp.getReassociationIndices();
166 if (expandReInds.size() == 0) {
177 for (
auto [expandReassociation, collapseReassociation] :
178 llvm::zip_equal(expandReInds, collapseReInds)) {
179 if (collapseReassociation.size() == expandReassociation.size()) {
186 collapseReassociation.front(), collapseReassociation.size());
187 int64_t numCollapsedDynamic =
188 llvm::count_if(collapsedStaticShapes, ShapedType::isDynamic);
190 expandReassociation.front(), expandReassociation.size());
191 int64_t numExpandedDynamic =
192 llvm::count_if(expandedStaticShapes, ShapedType::isDynamic);
193 if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
194 collapsedStaticShapes != expandedStaticShapes) {
201 if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
213 int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
216 for (
size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
217 auto &collapseReassociation = collapseReInds[idx];
218 auto &expandReassociation = expandReInds[idx];
226 if (collapseReassociation.size() == expandReassociation.size()) {
227 for (
size_t i = 0; i < collapseReassociation.size(); ++i) {
228 newCollapseReInds.push_back({newCollapseIndex++});
229 newExpandReInds.push_back({newExpandIndex++});
230 newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
239 if (collapseReassociation.size() != 1) {
241 for (
size_t i = 0; i < collapseReassociation.size(); ++i) {
242 newCollapseReassociation.push_back(newCollapseIndex++);
243 newExpandReInds.push_back({newExpandIndex++});
244 newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
247 newCollapseReInds.push_back(newCollapseReassociation);
255 for (
size_t i = 0; i < expandReassociation.size(); ++i) {
256 newExpandReassociation.push_back(newExpandIndex++);
257 newCollapseReInds.push_back({newCollapseIndex++});
258 newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
260 newExpandReInds.push_back(newExpandReassociation);
268 auto expandResultType = expandOp.getResultType().clone(staticSizes);
269 Value newCollapseSrc = collapseOp.getSrc();
273 if (newExpandReInds.size() != newExpandSizes.size()) {
274 newCollapseSrc = rewriter.
create<tensor::ExpandShapeOp>(
275 loc, expandResultType, newCollapseSrc, newExpandReInds,
282 Value replacement = newCollapseSrc;
283 if (newCollapseReInds.size() != newExpandSizes.size()) {
284 replacement = rewriter.
create<tensor::CollapseShapeOp>(
285 loc, newCollapseSrc, newCollapseReInds);
287 rewriter.
replaceOp(expandOp, replacement);
323 struct BubbleUpExpandShapeThroughExtractSlice
327 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
330 sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
332 if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
346 expandShapeOp.getOutputShape(), rewriter);
349 Location loc = expandShapeOp->getLoc();
366 expandShapeOp.getReassociationIndices()) {
375 for (
long expandedDim : indices) {
379 reassocGroupSizes.push_back(expandedShape[expandedDim]);
380 reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
381 collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
385 llvm::map_to_vector(reassocGroupOffsets, [&](
OpFoldResult ofr) {
390 .
create<affine::AffineLinearizeIndexOp>(loc, offsetVals,
394 collapsedOffsets.push_back(collapsedOffset);
395 collapsedSizes.push_back(collapsedSize);
406 shape, expandShapeOp.getResultType().getElementType());
409 Value newSliceOp = rewriter.
create<tensor::ExtractSliceOp>(
410 loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
413 sliceOp, resultType, newSliceOp,
414 expandShapeOp.getReassociationIndices(), expandedSizes);
422 checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
423 tensor::ExpandShapeOp expandShapeOp,
426 if (!expandShapeOp) {
428 sliceOp,
"tensor.extract_slice source not produced by expand_shape");
431 if (!sliceOp.hasUnitStride()) {
433 sliceOp,
"unsupported: non-unit stride. Only contiguous slices can "
434 "be supported in this transformation.");
440 if (
static_cast<size_t>(sliceOp.getResultType().getRank()) !=
443 "unimplemented: rank reducing slice");
448 expandShapeOp.getOutputShape(), rewriter);
451 isZeroOffsetAndFullSize =
455 FailureOr<bool> maybeEqual =
457 return llvm::succeeded(maybeEqual) && maybeEqual.value();
471 expandShapeOp.getReassociationIndices()) {
473 int64_t e = indices.size();
487 int64_t expandedDim = indices[i];
488 if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
489 outputShape[expandedDim])) {
491 sliceOp,
"Not a contiguous slice of the expanded tensor.");
573 struct BubbleUpCollapseShapeThroughExtractSlice
577 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
579 auto collapseShapeOp =
580 sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
581 if (!collapseShapeOp) {
584 "tensor.extract_slice source not produced by tensor.collapse_shape");
587 if (!sliceOp.hasUnitStride()) {
589 sliceOp,
"unsupported: non-unit stride. Only contiguous slices can "
590 "be supported in this transformation.");
601 if (
static_cast<size_t>(sliceOp.getResultType().getRank()) !=
602 collapsedSizes.size()) {
604 "unimplemented: rank reducing slice");
609 collapseShapeOp.getReassociationIndices();
619 for (
auto [collapsedSize, collapsedOffset, reassocIndices] :
620 llvm::zip_equal(collapsedSizes, collapsedOffsets,
621 collapseShapeOp.getReassociationIndices())) {
626 if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
627 int nonUnitSizeCount = 0;
628 for (int64_t expandedShapeIdx : reassocIndices) {
629 if (srcShape[expandedShapeIdx] != 1) {
631 expandedSizes.push_back(collapsedSize);
632 expandedOffsets.push_back(collapsedOffset);
640 if (nonUnitSizeCount != 1) {
643 "unsupported: slice cannot be verified to be contiguous");
665 int64_t currentCollapsedOffset =
671 reassocIndices.rend());
673 int64_t reassocGroupSize = reassocIndices.size();
677 for (; idx < reassocGroupSize; ++idx) {
678 int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
680 if (currentCollapsedsize < expandedShapeSize)
685 if ((currentCollapsedsize % expandedShapeSize) != 0 ||
686 (currentCollapsedOffset % expandedShapeSize) != 0) {
688 sliceOp,
"unsupported: cannot be extracted as a contiguous slice "
689 "of the src of the collapse_shape");
692 groupExpandedSizes.push_back(rewriter.
getIndexAttr(expandedShapeSize));
693 groupExpandedOffsets.push_back(rewriter.
getIndexAttr(0));
695 currentCollapsedsize /= expandedShapeSize;
696 currentCollapsedOffset /= expandedShapeSize;
700 if (idx < reassocGroupSize) {
701 int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
702 int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
705 if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
707 sliceOp,
"unsupported: slice cannot be extracted as a contiguous "
708 "slice of the src of the collapse_shape");
711 groupExpandedSizes.push_back(
713 groupExpandedOffsets.push_back(rewriter.
getIndexAttr(offsetInDim));
715 currentCollapsedOffset /= expandedShapeSize;
725 for (idx++; idx < reassocGroupSize; ++idx) {
726 int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
727 int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
729 groupExpandedOffsets.push_back(rewriter.
getIndexAttr(offsetInDim));
730 currentCollapsedOffset /= expandedShapeSize;
733 expandedSizes.append(groupExpandedSizes.rbegin(),
734 groupExpandedSizes.rend());
735 expandedOffsets.append(groupExpandedOffsets.rbegin(),
736 groupExpandedOffsets.rend());
739 Value newSliceOp = rewriter.
create<tensor::ExtractSliceOp>(
740 collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets,
741 expandedSizes, expandedStrides);
743 sliceOp, sliceOp.getResultType(), newSliceOp,
744 collapseShapeOp.getReassociationIndices());
755 .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
756 FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
757 FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
758 FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
759 FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
770 patterns.add<BubbleUpExpandShapeThroughExtractSlice,
771 BubbleUpCollapseShapeThroughExtractSlice>(
patterns.getContext());
Base type for affine expression.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched and rewritten.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replaces the original op and returns the newly created op).
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users. An SSA value is either a BlockArgument or the result of an operation.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any other AffineApplyOp supplied in the operands.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other ops.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract_slice op above its producer.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor.collapse_shape ops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of ops.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)).
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult(OpFoldResult ofr) for a single OpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which ShapedType::isDynamic is true are replaced by dynamicValues.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions with static size 1.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.