41 auto srcType = op.getSourceVectorType();
42 auto dstType = op.getDestVectorType();
44 if (op.getOffsets().getValue().empty())
47 auto loc = op.getLoc();
48 int64_t rankDiff = dstType.getRank() - srcType.getRank();
49 assert(rankDiff >= 0);
53 int64_t rankRest = dstType.getRank() - rankDiff;
63 auto stridedSliceInnerOp = rewriter.
create<InsertStridedSliceOp>(
64 loc, op.getSource(), extracted,
69 op, stridedSliceInnerOp.getResult(), op.getDest(),
92 setHasBoundedRewriteRecursion();
97 auto srcType = op.getSourceVectorType();
98 auto dstType = op.getDestVectorType();
99 int64_t srcRank = srcType.getRank();
102 if ((srcType.isScalable() || dstType.isScalable()) && srcRank == 1)
105 if (op.getOffsets().getValue().empty())
108 int64_t dstRank = dstType.getRank();
109 assert(dstRank >= srcRank);
110 if (dstRank != srcRank)
113 if (srcType == dstType) {
119 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
120 int64_t size = srcType.getShape().front();
122 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
124 auto loc = op.getLoc();
125 Value res = op.getDest();
128 int nSrc = srcType.getShape().front();
129 int nDest = dstType.getShape().front();
132 for (int64_t i = 0; i < nSrc; ++i)
134 Value scaledSource = rewriter.
create<ShuffleOp>(loc, op.getSource(),
135 op.getSource(), offsets);
140 for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
141 if (i < offset || i >= e || (i - offset) % stride != 0)
142 offsets.push_back(nDest + i);
144 offsets.push_back((i - offset) / stride);
155 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
156 off += stride, ++idx) {
158 Value extractedSource =
159 rewriter.
create<ExtractOp>(loc, op.getSource(), idx);
160 if (isa<VectorType>(extractedSource.
getType())) {
163 Value extractedDest =
164 rewriter.
create<ExtractOp>(loc, op.getDest(), off);
167 extractedSource = rewriter.
create<InsertStridedSliceOp>(
168 loc, extractedSource, extractedDest,
173 res = rewriter.
create<InsertOp>(loc, extractedSource, res, off);
190 auto dstType = op.getType();
191 auto srcType = op.getSourceVectorType();
194 if (dstType.isScalable() || srcType.isScalable())
197 assert(!op.getOffsets().getValue().empty() &&
"Unexpected empty offsets");
200 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
201 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
203 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
205 assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
208 if (op.getOffsets().getValue().size() != 1)
212 offsets.reserve(size);
213 for (int64_t off = offset, e = offset + size * stride; off < e;
215 offsets.push_back(off);
217 op.getVector(), offsets);
230 std::function<
bool(ExtractStridedSliceOp)> controlFn,
236 if (controlFn && !controlFn(op))
240 if (op.getOffsets().getValue().size() != 1)
244 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
245 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
247 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
251 elements.reserve(size);
252 for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
253 elements.push_back(rewriter.
create<ExtractOp>(loc, op.getVector(), i));
257 for (int64_t i = 0; i < size; ++i)
258 result = rewriter.
create<InsertOp>(loc, elements[i], result, i);
265 std::function<bool(ExtractStridedSliceOp)> controlFn;
279 setHasBoundedRewriteRecursion();
284 auto dstType = op.getType();
286 assert(!op.getOffsets().getValue().empty() &&
"Unexpected empty offsets");
289 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
290 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
292 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
294 auto loc = op.getLoc();
295 auto elemType = dstType.getElementType();
296 assert(elemType.isSignlessIntOrIndexOrFloat());
300 if (op.getOffsets().getValue().size() == 1)
306 Value res = rewriter.
create<SplatOp>(loc, dstType, zero);
307 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
308 off += stride, ++idx) {
309 Value one = rewriter.
create<ExtractOp>(loc, op.getVector(), off);
310 Value extracted = rewriter.
create<ExtractStridedSliceOp>(
314 res = rewriter.
create<InsertOp>(loc, extracted, res, idx);
331 std::function<
bool(ExtractStridedSliceOp)> controlFn,
334 patterns.getContext(), std::move(controlFn), benefit);
350 [](ExtractStridedSliceOp op) {
351 return op.getType().isScalable() ||
352 op.getSourceVectorType().isScalable();
RewritePattern for InsertStridedSliceOp where source and destination vectors have the same rank.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for InsertStridedSliceOp where source and destination vectors have different ranks.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateVectorExtractStridedSliceToExtractInsertChainPatterns(RewritePatternSet &patterns, std::function< bool(ExtractStridedSliceOp)> controlFn=nullptr, PatternBenefit benefit=1)
Populate patterns with a pattern to breaks down 1-D extract_strided_slice ops into a chain of Extract...
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
Include the generated interface declarations.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...