23 struct MergeConsecutiveExtractSlice :
public OpRewritePattern<ExtractSliceOp> {
26 LogicalResult matchAndRewrite(ExtractSliceOp nextOp,
28 auto prevOp = nextOp.getSource().getDefiningOp<ExtractSliceOp>();
34 rewriter, nextOp.getLoc(), prevOp, nextOp, prevOp.getDroppedDims(),
35 newOffsets, newSizes, newStrides)))
39 prevOp.getSource(), newOffsets,
40 newSizes, newStrides);
47 template <
typename OpTy>
51 LogicalResult matchAndRewrite(OpTy nextOp,
53 auto prevOp = nextOp.getSource().template getDefiningOp<InsertSliceOp>();
57 if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
70 if (!prevOp.getSourceType().hasStaticShape() ||
71 !prevOp.getDestType().hasStaticShape())
75 nextOp, prevOp.getSource(), nextOp.getDest(), nextOp.getMixedOffsets(),
76 nextOp.getMixedSizes(), nextOp.getMixedStrides());
86 struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice
90 LogicalResult matchAndRewrite(ExtractSliceOp extractSliceOp,
93 llvm::SmallBitVector droppedDims = extractSliceOp.getDroppedDims();
94 if (droppedDims.none())
99 extractSliceOp.getSource().getDefiningOp<InsertSliceOp>();
102 llvm::SmallBitVector expandedDims = insertSliceOp.getDroppedDims();
106 if (expandedDims != droppedDims)
110 if (!insertSliceOp->hasOneUse())
122 for (int64_t i = 0, e = extractSliceOp.getSourceType().getRank(); i < e;
124 if (droppedDims.test(i))
126 newOffsets.push_back(extractSliceOp.getMixedOffsets()[i]);
127 newSizes.push_back(extractSliceOp.getMixedSizes()[i]);
128 newStrides.push_back(extractSliceOp.getMixedStrides()[i]);
131 extractSliceOp, insertSliceOp.getSource(), newOffsets,
132 newSizes, newStrides);
133 rewriter.
eraseOp(insertSliceOp);
154 struct DropRedundantRankExpansionOnInsertSliceOfExtractSlice final
158 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
160 auto extractSliceOp =
161 insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
162 if (!extractSliceOp) {
164 "source is not extract_slice");
168 if (!extractSliceOp->hasOneUse()) {
170 "source has multi-uses");
176 "insert_slice is not cast-like");
179 llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
180 llvm::SmallBitVector insertDroppedDims = insertSliceOp.getDroppedDims();
182 if (extractDroppedDims.size() < insertDroppedDims.size()) {
184 "insert_slice expands more dims");
191 unsigned insertDimPos = 0;
192 for (
unsigned extractDimPos = 0; extractDimPos < extractDroppedDims.size();
195 if (insertDimPos == insertDroppedDims.size())
198 bool isExtractDropped = extractDroppedDims[extractDimPos];
199 bool isInsertDropped = insertDroppedDims[insertDimPos];
202 if (isExtractDropped == isInsertDropped) {
204 }
else if (!isExtractDropped && isInsertDropped) {
207 "insert_slice drops more unit dims");
214 if (insertDimPos != insertDroppedDims.size()) {
216 "insert_slice has unmatched dims");
220 insertSliceOp, insertSliceOp.getType(), extractSliceOp.getSource(),
221 extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
222 extractSliceOp.getMixedStrides());
223 rewriter.
eraseOp(extractSliceOp);
232 patterns.add<MergeConsecutiveExtractSlice,
233 MergeConsecutiveInsertSlice<InsertSliceOp>,
234 MergeConsecutiveInsertSlice<ParallelInsertSliceOp>>(
240 patterns.add<DropRedundantRankExpansionOnExtractSliceOfInsertSlice,
241 DropRedundantRankExpansionOnInsertSliceOfExtractSlice>(
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
LogicalResult mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > producerOffsets, ArrayRef< OpFoldResult > producerSizes, ArrayRef< OpFoldResult > producerStrides, const llvm::SmallBitVector &droppedProducerDims, ArrayRef< OpFoldResult > consumerOffsets, ArrayRef< OpFoldResult > consumerSizes, ArrayRef< OpFoldResult > consumerStrides, SmallVector< OpFoldResult > &combinedOffsets, SmallVector< OpFoldResult > &combinedSizes, SmallVector< OpFoldResult > &combinedStrides)
Fills the combinedOffsets, combinedSizes and combinedStrides to use when combining a producer slice i...
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
bool isCastLikeInsertSliceOp(InsertSliceOp op)
A tensor.insert_slice is a cast-like operation if it merely rank-extends the source tensor or inserts...
void populateDropRedundantInsertSliceRankExpansionPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that drop redundant tensor.insert_slice rank expansions.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
const FrozenRewritePatternSet & patterns
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...