15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/Support/LogicalResult.h"
23 struct FoldExpandOfRankReducingExtract
27 LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
29 RankedTensorType resultType = expandShapeOp.getResultType();
31 expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
34 RankedTensorType srcType = extractSliceOp.getSourceType();
39 RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
40 srcType, extractSliceOp.getStaticOffsets(),
41 extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
42 if (nonReducingExtractType != resultType)
49 expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
57 struct FoldUnPaddingCollapseIntoExtract
61 LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
64 collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
68 if (!extractSliceOp || !extractSliceOp->hasOneUse())
74 collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
77 "expected unpadding collapse");
79 Value unPaddedExtractSlice = tensor::ExtractSliceOp::create(
80 rewriter, extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
81 extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
82 extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
83 rewriter.
replaceOp(collapseShapeOp, unPaddedExtractSlice);
89 template <
typename OpTy>
93 LogicalResult matchAndRewrite(OpTy insertSliceOp,
95 auto collapseShapeOp =
96 insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
99 RankedTensorType srcType = collapseShapeOp.getSrcType();
104 RankedTensorType nonReducingInsertType =
106 insertSliceOp.getDestType().getElementType());
107 if (nonReducingInsertType != srcType)
114 insertSliceOp.getDest(), mixedOffsets,
115 mixedSizes, mixedStrides);
122 template <
typename OpTy>
126 LogicalResult matchAndRewrite(OpTy insertSliceOp,
128 auto expandShapeOp = insertSliceOp.getSource()
129 .template getDefiningOp<tensor::ExpandShapeOp>();
136 expandShapeOp.getResultType(), expandShapeOp.getSrcType());
139 "expected rank increasing expansion");
142 insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
150 struct BubbleUpExpandThroughParallelCollapse
154 LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
157 expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
160 auto expandReInds = expandOp.getReassociationIndices();
161 auto collapseReInds = collapseOp.getReassociationIndices();
165 if (expandReInds.size() == 0) {
176 for (
auto [expandReassociation, collapseReassociation] :
177 llvm::zip_equal(expandReInds, collapseReInds)) {
178 if (collapseReassociation.size() == expandReassociation.size()) {
185 collapseReassociation.front(), collapseReassociation.size());
186 int64_t numCollapsedDynamic =
187 llvm::count_if(collapsedStaticShapes, ShapedType::isDynamic);
189 expandReassociation.front(), expandReassociation.size());
190 int64_t numExpandedDynamic =
191 llvm::count_if(expandedStaticShapes, ShapedType::isDynamic);
192 if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
193 collapsedStaticShapes != expandedStaticShapes) {
200 if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
212 int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
215 for (
size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
216 auto &collapseReassociation = collapseReInds[idx];
217 auto &expandReassociation = expandReInds[idx];
225 if (collapseReassociation.size() == expandReassociation.size()) {
226 for (
size_t i = 0; i < collapseReassociation.size(); ++i) {
227 newCollapseReInds.push_back({newCollapseIndex++});
228 newExpandReInds.push_back({newExpandIndex++});
229 newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
238 if (collapseReassociation.size() != 1) {
240 for (
size_t i = 0; i < collapseReassociation.size(); ++i) {
241 newCollapseReassociation.push_back(newCollapseIndex++);
242 newExpandReInds.push_back({newExpandIndex++});
243 newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
246 newCollapseReInds.push_back(newCollapseReassociation);
254 for (
size_t i = 0; i < expandReassociation.size(); ++i) {
255 newExpandReassociation.push_back(newExpandIndex++);
256 newCollapseReInds.push_back({newCollapseIndex++});
257 newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
259 newExpandReInds.push_back(newExpandReassociation);
267 auto expandResultType = expandOp.getResultType().clone(staticSizes);
268 Value newCollapseSrc = collapseOp.getSrc();
272 if (newExpandReInds.size() != newExpandSizes.size()) {
273 newCollapseSrc = tensor::ExpandShapeOp::create(
274 rewriter, loc, expandResultType, newCollapseSrc, newExpandReInds,
281 Value replacement = newCollapseSrc;
282 if (newCollapseReInds.size() != newExpandSizes.size()) {
283 replacement = tensor::CollapseShapeOp::create(
284 rewriter, loc, newCollapseSrc, newCollapseReInds);
286 rewriter.
replaceOp(expandOp, replacement);
322 struct BubbleUpExpandShapeThroughExtractSlice
326 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
329 sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
331 if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
345 expandShapeOp.getOutputShape(), rewriter);
348 Location loc = expandShapeOp->getLoc();
365 expandShapeOp.getReassociationIndices()) {
374 for (
long expandedDim : indices) {
378 reassocGroupSizes.push_back(expandedShape[expandedDim]);
379 reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
380 collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
384 llvm::map_to_vector(reassocGroupOffsets, [&](
OpFoldResult ofr) {
389 .
create<affine::AffineLinearizeIndexOp>(loc, offsetVals,
393 collapsedOffsets.push_back(collapsedOffset);
394 collapsedSizes.push_back(collapsedSize);
405 shape, expandShapeOp.getResultType().getElementType());
408 Value newSliceOp = tensor::ExtractSliceOp::create(
409 rewriter, loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
412 sliceOp, resultType, newSliceOp,
413 expandShapeOp.getReassociationIndices(), expandedSizes);
421 checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
422 tensor::ExpandShapeOp expandShapeOp,
425 if (!expandShapeOp) {
427 sliceOp,
"tensor.extract_slice source not produced by expand_shape");
430 if (!sliceOp.hasUnitStride()) {
432 sliceOp,
"unsupported: non-unit stride. Only contiguous slices can "
433 "be supported in this transformation.");
439 if (
static_cast<size_t>(sliceOp.getResultType().getRank()) !=
442 "unimplemented: rank reducing slice");
447 expandShapeOp.getOutputShape(), rewriter);
450 isZeroOffsetAndFullSize =
454 FailureOr<bool> maybeEqual =
456 return llvm::succeeded(maybeEqual) && maybeEqual.value();
470 expandShapeOp.getReassociationIndices()) {
472 int64_t e = indices.size();
486 int64_t expandedDim = indices[i];
487 if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
488 outputShape[expandedDim])) {
490 sliceOp,
"Not a contiguous slice of the expanded tensor.");
572 struct BubbleUpCollapseShapeThroughExtractSlice
576 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
578 auto collapseShapeOp =
579 sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
580 if (!collapseShapeOp) {
583 "tensor.extract_slice source not produced by tensor.collapse_shape");
586 if (!sliceOp.hasUnitStride()) {
588 sliceOp,
"unsupported: non-unit stride. Only contiguous slices can "
589 "be supported in this transformation.");
600 if (
static_cast<size_t>(sliceOp.getResultType().getRank()) !=
601 collapsedSizes.size()) {
603 "unimplemented: rank reducing slice");
608 collapseShapeOp.getReassociationIndices();
618 for (
auto [collapsedSize, collapsedOffset, reassocIndices] :
619 llvm::zip_equal(collapsedSizes, collapsedOffsets,
620 collapseShapeOp.getReassociationIndices())) {
625 if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
626 int nonUnitSizeCount = 0;
627 for (int64_t expandedShapeIdx : reassocIndices) {
628 if (srcShape[expandedShapeIdx] != 1) {
630 expandedSizes.push_back(collapsedSize);
631 expandedOffsets.push_back(collapsedOffset);
639 if (nonUnitSizeCount != 1) {
642 "unsupported: slice cannot be verified to be contiguous");
664 int64_t currentCollapsedOffset =
670 reassocIndices.rend());
672 int64_t reassocGroupSize = reassocIndices.size();
676 for (; idx < reassocGroupSize; ++idx) {
677 int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
679 if (currentCollapsedsize < expandedShapeSize)
684 if ((currentCollapsedsize % expandedShapeSize) != 0 ||
685 (currentCollapsedOffset % expandedShapeSize) != 0) {
687 sliceOp,
"unsupported: cannot be extracted as a contiguous slice "
688 "of the src of the collapse_shape");
691 groupExpandedSizes.push_back(rewriter.
getIndexAttr(expandedShapeSize));
692 groupExpandedOffsets.push_back(rewriter.
getIndexAttr(0));
694 currentCollapsedsize /= expandedShapeSize;
695 currentCollapsedOffset /= expandedShapeSize;
699 if (idx < reassocGroupSize) {
700 int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
701 int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
704 if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
706 sliceOp,
"unsupported: slice cannot be extracted as a contiguous "
707 "slice of the src of the collapse_shape");
710 groupExpandedSizes.push_back(
712 groupExpandedOffsets.push_back(rewriter.
getIndexAttr(offsetInDim));
714 currentCollapsedOffset /= expandedShapeSize;
724 for (idx++; idx < reassocGroupSize; ++idx) {
725 int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
726 int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
728 groupExpandedOffsets.push_back(rewriter.
getIndexAttr(offsetInDim));
729 currentCollapsedOffset /= expandedShapeSize;
732 expandedSizes.append(groupExpandedSizes.rbegin(),
733 groupExpandedSizes.rend());
734 expandedOffsets.append(groupExpandedOffsets.rbegin(),
735 groupExpandedOffsets.rend());
738 Value newSliceOp = tensor::ExtractSliceOp::create(
739 rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(),
740 expandedOffsets, expandedSizes, expandedStrides);
742 sliceOp, sliceOp.getResultType(), newSliceOp,
743 collapseShapeOp.getReassociationIndices());
754 .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
755 FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
756 FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
757 FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
758 FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
769 patterns.add<BubbleUpExpandShapeThroughExtractSlice,
770 BubbleUpCollapseShapeThroughExtractSlice>(
patterns.getContext());
Base type for affine expression.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other ops.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up a tensor.extract_slice op above its producer.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor.collapse_shape ops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of ops.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)).
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...