12 #include "llvm/Support/Debug.h"
19 struct FoldExpandOfRankReducingExtract
23 LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
25 RankedTensorType resultType = expandShapeOp.getResultType();
27 expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
30 RankedTensorType srcType = extractSliceOp.getSourceType();
35 RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
36 srcType, extractSliceOp.getStaticOffsets(),
37 extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
38 if (nonReducingExtractType != resultType)
45 expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
53 struct FoldUnPaddingCollapseIntoExtract
57 LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
60 collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
64 if (!extractSliceOp || !extractSliceOp->hasOneUse())
70 collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
73 "expected unpadding collapse");
75 Value unPaddedExtractSlice = rewriter.
create<tensor::ExtractSliceOp>(
76 extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
77 extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
78 extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
79 rewriter.
replaceOp(collapseShapeOp, unPaddedExtractSlice);
85 template <
typename OpTy>
89 LogicalResult matchAndRewrite(OpTy insertSliceOp,
91 auto collapseShapeOp =
92 insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
95 RankedTensorType srcType = collapseShapeOp.getSrcType();
100 RankedTensorType nonReducingInsertType =
102 insertSliceOp.getDestType().getElementType());
103 if (nonReducingInsertType != srcType)
110 insertSliceOp.getDest(), mixedOffsets,
111 mixedSizes, mixedStrides);
118 template <
typename OpTy>
122 LogicalResult matchAndRewrite(OpTy insertSliceOp,
124 auto expandShapeOp = insertSliceOp.getSource()
125 .template getDefiningOp<tensor::ExpandShapeOp>();
132 expandShapeOp.getResultType(), expandShapeOp.getSrcType());
135 "expected rank increasing expansion");
138 insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
146 struct BubbleUpExpandThroughParallelCollapse
150 LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
153 expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
156 auto expandReInds = expandOp.getReassociationIndices();
157 auto collapseReInds = collapseOp.getReassociationIndices();
161 for (
auto [expandReassociation, collapseReassociation] :
162 llvm::zip_equal(expandReInds, collapseReInds)) {
163 if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
173 expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
175 int64_t index = 0, expandIndex = 0, collapseIndex = 0;
176 for (
auto [idx, collapseReassociation] :
llvm::enumerate(collapseReInds)) {
177 if (collapseReassociation.size() != 1) {
179 for (
size_t i = 0; i < collapseReassociation.size(); ++i) {
180 newCollapseReassociation.push_back(index);
181 newExpandReInds.push_back({index++});
182 newExpandSizes.push_back(collapseSizes[collapseIndex++]);
184 newCollapseReInds.push_back(newCollapseReassociation);
189 auto expandReassociation = expandReInds[idx];
190 for (
size_t i = 0; i < expandReassociation.size(); ++i) {
191 newExpandReassociation.push_back(index);
192 newCollapseReInds.push_back({index++});
193 newExpandSizes.push_back(expandSizes[expandIndex++]);
195 newExpandReInds.push_back(newExpandReassociation);
203 auto expandResultType = expandOp.getResultType().clone(staticSizes);
204 auto newExpand = rewriter.
create<tensor::ExpandShapeOp>(
205 loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
208 expandOp, newExpand.getResult(), newCollapseReInds);
218 .
add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
219 FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
220 FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
221 FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
222 FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
228 patterns.
add<BubbleUpExpandThroughParallelCollapse>(patterns.
getContext());
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other o...
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...