23 struct MergeConsecutiveExtractSlice :
public OpRewritePattern<ExtractSliceOp> {
28 auto prevOp = nextOp.getSource().getDefiningOp<ExtractSliceOp>();
34 rewriter, nextOp.getLoc(), prevOp, nextOp, prevOp.getDroppedDims(),
35 newOffsets, newSizes, newStrides)))
39 prevOp.getSource(), newOffsets,
40 newSizes, newStrides);
47 template <
typename OpTy>
53 auto prevOp = nextOp.getSource().template getDefiningOp<InsertSliceOp>();
57 if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
70 if (!prevOp.getSourceType().hasStaticShape() ||
71 !prevOp.getDestType().hasStaticShape())
75 nextOp, prevOp.getSource(), nextOp.getDest(), nextOp.getMixedOffsets(),
76 nextOp.getMixedSizes(), nextOp.getMixedStrides());
86 struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice
93 llvm::SmallBitVector droppedDims = extractSliceOp.getDroppedDims();
94 if (droppedDims.none())
99 extractSliceOp.getSource().getDefiningOp<InsertSliceOp>();
102 llvm::SmallBitVector expandedDims = insertSliceOp.getDroppedDims();
106 if (expandedDims != droppedDims)
110 if (!insertSliceOp->hasOneUse())
122 for (int64_t i = 0, e = extractSliceOp.getSourceType().getRank(); i < e;
124 if (droppedDims.test(i))
126 newOffsets.push_back(extractSliceOp.getMixedOffsets()[i]);
127 newSizes.push_back(extractSliceOp.getMixedSizes()[i]);
128 newStrides.push_back(extractSliceOp.getMixedStrides()[i]);
131 extractSliceOp, insertSliceOp.getSource(), newOffsets,
132 newSizes, newStrides);
133 rewriter.
eraseOp(insertSliceOp);
154 struct DropRedundantRankExpansionOnInsertSliceOfExtractSlice final
158 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
160 auto extractSliceOp =
161 insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
162 if (!extractSliceOp) {
164 "source is not extract_slice");
168 if (!extractSliceOp->hasOneUse()) {
170 "source has multi-uses");
176 "insert_slice is not cast-like");
179 llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
180 llvm::SmallBitVector insertDroppedDims = insertSliceOp.getDroppedDims();
182 if (extractDroppedDims.size() < insertDroppedDims.size()) {
184 "insert_slice expands more dims");
191 unsigned insertDimPos = 0;
192 for (
unsigned extractDimPos = 0; extractDimPos < extractDroppedDims.size();
195 if (insertDimPos == insertDroppedDims.size())
198 bool isExtractDropped = extractDroppedDims[extractDimPos];
199 bool isInsertDropped = insertDroppedDims[insertDimPos];
202 if (isExtractDropped == isInsertDropped) {
204 }
else if (!isExtractDropped && isInsertDropped) {
207 "insert_slice drops more unit dims");
214 if (insertDimPos != insertDroppedDims.size()) {
216 "insert_slice has unmatched dims");
220 insertSliceOp, insertSliceOp.getType(), extractSliceOp.getSource(),
221 extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
222 extractSliceOp.getMixedStrides());
223 rewriter.
eraseOp(extractSliceOp);
232 patterns.
add<MergeConsecutiveExtractSlice,
233 MergeConsecutiveInsertSlice<InsertSliceOp>,
234 MergeConsecutiveInsertSlice<ParallelInsertSliceOp>>(
240 patterns.
add<DropRedundantRankExpansionOnExtractSliceOfInsertSlice,
241 DropRedundantRankExpansionOnInsertSliceOfExtractSlice>(
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
LogicalResult mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > producerOffsets, ArrayRef< OpFoldResult > producerSizes, ArrayRef< OpFoldResult > producerStrides, const llvm::SmallBitVector &droppedProducerDims, ArrayRef< OpFoldResult > consumerOffsets, ArrayRef< OpFoldResult > consumerSizes, ArrayRef< OpFoldResult > consumerStrides, SmallVector< OpFoldResult > &combinedOffsets, SmallVector< OpFoldResult > &combinedSizes, SmallVector< OpFoldResult > &combinedStrides)
Fills the combinedOffsets, combinedSizes and combinedStrides to use when combining a producer slice i...
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
bool isCastLikeInsertSliceOp(InsertSliceOp op)
A tensor.insert_slice is a cast-like operation if it merely rank-extends the source tensor or inserts...
void populateDropRedundantInsertSliceRankExpansionPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that drop redundant tensor.insert_slice rank expansions.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...