31struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice
35 LogicalResult matchAndRewrite(ExtractSliceOp extractSliceOp,
36 PatternRewriter &rewriter)
const override {
38 llvm::SmallBitVector droppedDims = extractSliceOp.getDroppedDims();
39 if (droppedDims.none())
44 extractSliceOp.getSource().getDefiningOp<InsertSliceOp>();
47 llvm::SmallBitVector expandedDims = insertSliceOp.getDroppedDims();
50 if (!expandedDims.subsetOf(droppedDims))
54 if (!insertSliceOp->hasOneUse())
63 OpBuilder::InsertionGuard g(rewriter);
65 SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
66 SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
67 SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
68 SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
69 for (int64_t i = 0, e = extractSliceOp.getSourceType().getRank(); i < e;
71 if (expandedDims.test(i))
73 newOffsets.push_back(mixedOffsets[i]);
74 newSizes.push_back(mixedSizes[i]);
75 newStrides.push_back(mixedStrides[i]);
78 extractSliceOp, extractSliceOp.getResultType(),
79 insertSliceOp.getSource(), newOffsets, newSizes, newStrides);
80 rewriter.
eraseOp(insertSliceOp);
101struct DropRedundantRankExpansionOnInsertSliceOfExtractSlice final
103 using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
105 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
106 PatternRewriter &rewriter)
const override {
107 auto extractSliceOp =
108 insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
109 if (!extractSliceOp) {
111 "source is not extract_slice");
115 if (!extractSliceOp->hasOneUse()) {
117 "source has multi-uses");
123 "insert_slice is not cast-like");
126 llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
127 llvm::SmallBitVector insertDroppedDims = insertSliceOp.getDroppedDims();
129 if (extractDroppedDims.size() < insertDroppedDims.size()) {
131 "insert_slice expands more dims");
138 unsigned insertDimPos = 0;
139 for (
unsigned extractDimPos = 0; extractDimPos < extractDroppedDims.size();
142 if (insertDimPos == insertDroppedDims.size())
145 bool isExtractDropped = extractDroppedDims[extractDimPos];
146 bool isInsertDropped = insertDroppedDims[insertDimPos];
149 if (isExtractDropped == isInsertDropped) {
151 }
else if (!isExtractDropped && isInsertDropped) {
154 "insert_slice drops more unit dims");
161 if (insertDimPos != insertDroppedDims.size()) {
163 "insert_slice has unmatched dims");
167 insertSliceOp, insertSliceOp.getType(), extractSliceOp.getSource(),
168 extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
169 extractSliceOp.getMixedStrides());
170 rewriter.
eraseOp(extractSliceOp);
179 patterns.
add<DropRedundantRankExpansionOnExtractSliceOfInsertSlice,
180 DropRedundantRankExpansionOnInsertSliceOfExtractSlice>(
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
bool isCastLikeInsertSliceOp(InsertSliceOp op)
A tensor.insert_slice is a cast-like operation if it merely rank-extends the source tensor or inserts...
void populateDropRedundantInsertSliceRankExpansionPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that drop redundant tensor.insert_slice rank expansions.
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...