41 auto srcType = op.getSourceVectorType();
42 auto dstType = op.getDestVectorType();
44 if (op.getOffsets().getValue().empty())
47 auto loc = op.getLoc();
48 int64_t rankDiff = dstType.getRank() - srcType.getRank();
49 assert(rankDiff >= 0);
53 int64_t rankRest = dstType.getRank() - rankDiff;
63 auto stridedSliceInnerOp = rewriter.
create<InsertStridedSliceOp>(
64 loc, op.getSource(), extracted,
69 op, stridedSliceInnerOp.getResult(), op.getDest(),
92 setHasBoundedRewriteRecursion();
97 auto srcType = op.getSourceVectorType();
98 auto dstType = op.getDestVectorType();
100 if (op.getOffsets().getValue().empty())
103 int64_t srcRank = srcType.getRank();
104 int64_t dstRank = dstType.getRank();
105 assert(dstRank >= srcRank);
106 if (dstRank != srcRank)
109 if (srcType == dstType) {
115 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
116 int64_t size = srcType.getShape().front();
118 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
120 auto loc = op.getLoc();
121 Value res = op.getDest();
124 int nSrc = srcType.getShape().front();
125 int nDest = dstType.getShape().front();
128 for (int64_t i = 0; i < nSrc; ++i)
130 Value scaledSource = rewriter.
create<ShuffleOp>(loc, op.getSource(),
131 op.getSource(), offsets);
136 for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
137 if (i < offset || i >= e || (i - offset) % stride != 0)
138 offsets.push_back(nDest + i);
140 offsets.push_back((i - offset) / stride);
151 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
152 off += stride, ++idx) {
154 Value extractedSource =
155 rewriter.
create<ExtractOp>(loc, op.getSource(), idx);
156 if (isa<VectorType>(extractedSource.
getType())) {
159 Value extractedDest =
160 rewriter.
create<ExtractOp>(loc, op.getDest(), off);
163 extractedSource = rewriter.
create<InsertStridedSliceOp>(
164 loc, extractedSource, extractedDest,
169 res = rewriter.
create<InsertOp>(loc, extractedSource, res, off);
186 auto dstType = op.getType();
188 assert(!op.getOffsets().getValue().empty() &&
"Unexpected empty offsets");
191 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
192 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
194 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
196 assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
199 if (op.getOffsets().getValue().size() != 1)
203 offsets.reserve(size);
204 for (int64_t off = offset, e = offset + size * stride; off < e;
206 offsets.push_back(off);
208 op.getVector(), offsets);
221 std::function<
bool(ExtractStridedSliceOp)> controlFn,
227 if (controlFn && !controlFn(op))
231 if (op.getOffsets().getValue().size() != 1)
235 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
236 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
238 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
242 elements.reserve(size);
243 for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
244 elements.push_back(rewriter.
create<ExtractOp>(loc, op.getVector(), i));
248 for (int64_t i = 0; i < size; ++i)
249 result = rewriter.
create<InsertOp>(loc, elements[i], result, i);
256 std::function<bool(ExtractStridedSliceOp)> controlFn;
270 setHasBoundedRewriteRecursion();
275 auto dstType = op.getType();
277 assert(!op.getOffsets().getValue().empty() &&
"Unexpected empty offsets");
280 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
281 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
283 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
285 auto loc = op.getLoc();
286 auto elemType = dstType.getElementType();
287 assert(elemType.isSignlessIntOrIndexOrFloat());
291 if (op.getOffsets().getValue().size() == 1)
297 Value res = rewriter.
create<SplatOp>(loc, dstType, zero);
298 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
299 off += stride, ++idx) {
300 Value one = rewriter.
create<ExtractOp>(loc, op.getVector(), off);
301 Value extracted = rewriter.
create<ExtractStridedSliceOp>(
305 res = rewriter.
create<InsertOp>(loc, extracted, res, idx);
320 std::function<
bool(ExtractStridedSliceOp)> controlFn,
323 patterns.getContext(), std::move(controlFn), benefit);
RewritePattern for InsertStridedSliceOp where source and destination vectors have the same rank.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for InsertStridedSliceOp where source and destination vectors have different ranks.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateVectorExtractStridedSliceToExtractInsertChainPatterns(RewritePatternSet &patterns, std::function< bool(ExtractStridedSliceOp)> controlFn=nullptr, PatternBenefit benefit=1)
Populate patterns with a pattern to breaks down 1-D extract_strided_slice ops into a chain of Extract...
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
Include the generated interface declarations.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...