23 Value into, int64_t offset) {
24 auto vectorType = cast<VectorType>(into.
getType());
25 if (vectorType.getRank() > 1)
26 return rewriter.
create<InsertOp>(loc, from, into, offset);
27 return rewriter.
create<vector::InsertElementOp>(
28 loc, vectorType, from, into,
29 rewriter.
create<arith::ConstantIndexOp>(loc, offset));
35 auto vectorType = cast<VectorType>(vector.
getType());
36 if (vectorType.getRank() > 1)
37 return rewriter.
create<ExtractOp>(loc, vector, offset);
38 return rewriter.
create<vector::ExtractElementOp>(
39 loc, vectorType.getElementType(), vector,
40 rewriter.
create<arith::ConstantIndexOp>(loc, offset));
63 auto srcType = op.getSourceVectorType();
64 auto dstType = op.getDestVectorType();
66 if (op.getOffsets().getValue().empty())
70 int64_t rankDiff = dstType.getRank() - srcType.getRank();
71 assert(rankDiff >= 0);
75 int64_t rankRest = dstType.getRank() - rankDiff;
85 auto stridedSliceInnerOp = rewriter.
create<InsertStridedSliceOp>(
86 loc, op.getSource(), extracted,
91 op, stridedSliceInnerOp.
getResult(), op.getDest(),
114 setHasBoundedRewriteRecursion();
119 auto srcType = op.getSourceVectorType();
120 auto dstType = op.getDestVectorType();
122 if (op.getOffsets().getValue().empty())
125 int64_t srcRank = srcType.getRank();
126 int64_t dstRank = dstType.getRank();
127 assert(dstRank >= srcRank);
128 if (dstRank != srcRank)
131 if (srcType == dstType) {
137 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
138 int64_t size = srcType.getShape().front();
140 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
143 Value res = op.getDest();
146 int nSrc = srcType.getShape().front();
147 int nDest = dstType.getShape().front();
150 for (int64_t i = 0; i < nSrc; ++i)
152 Value scaledSource = rewriter.
create<ShuffleOp>(loc, op.getSource(),
153 op.getSource(), offsets);
158 for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
159 if (i < offset || i >= e || (i - offset) % stride != 0)
160 offsets.push_back(nDest + i);
162 offsets.push_back((i - offset) / stride);
173 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
174 off += stride, ++idx) {
176 Value extractedSource =
extractOne(rewriter, loc, op.getSource(), idx);
177 if (isa<VectorType>(extractedSource.
getType())) {
183 extractedSource = rewriter.
create<InsertStridedSliceOp>(
184 loc, extractedSource, extractedDest,
189 res =
insertOne(rewriter, loc, extractedSource, res, off);
206 auto dstType = op.getType();
208 assert(!op.getOffsets().getValue().empty() &&
"Unexpected empty offsets");
211 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
212 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
214 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
216 assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
219 if (op.getOffsets().getValue().size() != 1)
223 offsets.reserve(size);
224 for (int64_t off = offset, e = offset + size * stride; off < e;
226 offsets.push_back(off);
228 op.getVector(), offsets);
241 std::function<
bool(ExtractStridedSliceOp)> controlFn,
247 if (controlFn && !controlFn(op))
251 if (op.getOffsets().getValue().size() != 1)
255 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
256 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
258 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
262 elements.reserve(size);
263 for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
264 elements.push_back(rewriter.
create<ExtractOp>(loc, op.getVector(), i));
268 for (int64_t i = 0; i < size; ++i)
269 result = rewriter.
create<InsertOp>(loc, elements[i], result, i);
276 std::function<bool(ExtractStridedSliceOp)> controlFn;
290 setHasBoundedRewriteRecursion();
295 auto dstType = op.getType();
297 assert(!op.getOffsets().getValue().empty() &&
"Unexpected empty offsets");
300 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
301 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
303 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
306 auto elemType = dstType.getElementType();
307 assert(elemType.isSignlessIntOrIndexOrFloat());
311 if (op.getOffsets().getValue().size() == 1)
317 Value res = rewriter.
create<SplatOp>(loc, dstType, zero);
318 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
319 off += stride, ++idx) {
321 Value extracted = rewriter.
create<ExtractStridedSliceOp>(
325 res =
insertOne(rewriter, loc, extracted, res, idx);
340 std::function<
bool(ExtractStridedSliceOp)> controlFn,
343 patterns.
getContext(), std::move(controlFn), benefit);
RewritePattern for InsertStridedSliceOp where source and destination vectors have the same rank.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for InsertStridedSliceOp where source and destination vectors have different ranks.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateVectorExtractStridedSliceToExtractInsertChainPatterns(RewritePatternSet &patterns, std::function< bool(ExtractStridedSliceOp)> controlFn=nullptr, PatternBenefit benefit=1)
Populate patterns with a pattern to breaks down 1-D extract_strided_slice ops into a chain of Extract...
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
Include the generated interface declarations.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...