12 #include "llvm/Support/Debug.h"
19 struct FoldExpandOfRankReducingExtract
23 LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
25 RankedTensorType resultType = expandShapeOp.getResultType();
27 expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
30 RankedTensorType srcType = extractSliceOp.getSourceType();
35 RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
36 srcType, extractSliceOp.getStaticOffsets(),
37 extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
38 if (nonReducingExtractType != resultType)
45 expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
53 struct FoldUnPaddingCollapseIntoExtract
57 LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
60 collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
64 if (!extractSliceOp || !extractSliceOp->hasOneUse())
70 collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
73 "expected unpadding collapse");
75 Value unPaddedExtractSlice = rewriter.
create<tensor::ExtractSliceOp>(
76 extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
77 extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
78 extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
79 rewriter.
replaceOp(collapseShapeOp, unPaddedExtractSlice);
85 template <
typename OpTy>
89 LogicalResult matchAndRewrite(OpTy insertSliceOp,
91 auto collapseShapeOp =
92 insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
95 RankedTensorType srcType = collapseShapeOp.getSrcType();
100 RankedTensorType nonReducingInsertType =
102 insertSliceOp.getDestType().getElementType());
103 if (nonReducingInsertType != srcType)
110 insertSliceOp.getDest(), mixedOffsets,
111 mixedSizes, mixedStrides);
118 template <
typename OpTy>
122 LogicalResult matchAndRewrite(OpTy insertSliceOp,
124 auto expandShapeOp = insertSliceOp.getSource()
125 .template getDefiningOp<tensor::ExpandShapeOp>();
132 expandShapeOp.getResultType(), expandShapeOp.getSrcType());
135 "expected rank increasing expansion");
138 insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
148 .
add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
149 FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
150 FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
151 FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
152 FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other o...
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...