23 Value into, int64_t offset) {
24 auto vectorType = into.
getType().
cast<VectorType>();
25 if (vectorType.getRank() > 1)
26 return rewriter.
create<InsertOp>(loc, from, into, offset);
27 return rewriter.
create<vector::InsertElementOp>(
28 loc, vectorType, from, into,
29 rewriter.
create<arith::ConstantIndexOp>(loc, offset));
35 auto vectorType = vector.
getType().
cast<VectorType>();
36 if (vectorType.getRank() > 1)
37 return rewriter.
create<ExtractOp>(loc, vector, offset);
38 return rewriter.
create<vector::ExtractElementOp>(
39 loc, vectorType.getElementType(), vector,
40 rewriter.
create<arith::ConstantIndexOp>(loc, offset));
63 auto srcType = op.getSourceVectorType();
64 auto dstType = op.getDestVectorType();
66 if (op.getOffsets().getValue().empty())
69 auto loc = op.getLoc();
70 int64_t rankDiff = dstType.getRank() - srcType.getRank();
71 assert(rankDiff >= 0);
75 int64_t rankRest = dstType.getRank() - rankDiff;
85 auto stridedSliceInnerOp = rewriter.
create<InsertStridedSliceOp>(
86 loc, op.getSource(), extracted,
91 op, stridedSliceInnerOp.getResult(), op.getDest(),
114 setHasBoundedRewriteRecursion();
119 auto srcType = op.getSourceVectorType();
120 auto dstType = op.getDestVectorType();
122 if (op.getOffsets().getValue().empty())
125 int64_t srcRank = srcType.getRank();
126 int64_t dstRank = dstType.getRank();
127 assert(dstRank >= srcRank);
128 if (dstRank != srcRank)
131 if (srcType == dstType) {
137 op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
138 int64_t size = srcType.getShape().front();
140 op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
142 auto loc = op.getLoc();
143 Value res = op.getDest();
146 int nSrc = srcType.getShape().front();
147 int nDest = dstType.getShape().front();
150 for (int64_t i = 0; i < nSrc; ++i)
152 Value scaledSource = rewriter.
create<ShuffleOp>(loc, op.getSource(),
153 op.getSource(), offsets);
158 for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
159 if (i < offset || i >= e || (i - offset) % stride != 0)
160 offsets.push_back(nDest + i);
162 offsets.push_back((i - offset) / stride);
173 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
174 off += stride, ++idx) {
176 Value extractedSource =
extractOne(rewriter, loc, op.getSource(), idx);
177 if (extractedSource.
getType().
isa<VectorType>()) {
183 extractedSource = rewriter.
create<InsertStridedSliceOp>(
184 loc, extractedSource, extractedDest,
189 res =
insertOne(rewriter, loc, extractedSource, res, off);
206 auto dstType = op.getType();
208 assert(!op.getOffsets().getValue().empty() &&
"Unexpected empty offsets");
211 op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
213 op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
215 op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
217 assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
220 if (op.getOffsets().getValue().size() != 1)
224 offsets.reserve(size);
225 for (int64_t off = offset, e = offset + size * stride; off < e;
227 offsets.push_back(off);
243 std::function<
bool(ExtractStridedSliceOp)> controlFn,
249 if (controlFn && !controlFn(op))
253 if (op.getOffsets().getValue().size() != 1)
257 op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
259 op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
261 op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
265 elements.reserve(size);
266 for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
267 elements.push_back(rewriter.
create<ExtractOp>(loc, op.getVector(), i));
271 for (int64_t i = 0; i < size; ++i)
272 result = rewriter.
create<InsertOp>(loc, elements[i], result, i);
279 std::function<bool(ExtractStridedSliceOp)> controlFn;
293 setHasBoundedRewriteRecursion();
298 auto dstType = op.getType();
300 assert(!op.getOffsets().getValue().empty() &&
"Unexpected empty offsets");
303 op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
305 op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
307 op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
309 auto loc = op.getLoc();
310 auto elemType = dstType.getElementType();
311 assert(elemType.isSignlessIntOrIndexOrFloat());
315 if (op.getOffsets().getValue().size() == 1)
321 Value res = rewriter.
create<SplatOp>(loc, dstType, zero);
322 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
323 off += stride, ++idx) {
325 Value extracted = rewriter.
create<ExtractStridedSliceOp>(
329 res =
insertOne(rewriter, loc, extracted, res, idx);
344 std::function<
bool(ExtractStridedSliceOp)> controlFn,
347 patterns.
getContext(), std::move(controlFn), benefit);
RewritePattern for InsertStridedSliceOp where source and destination vectors have the same rank.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for InsertStridedSliceOp where source and destination vectors have different ranks.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
Attribute getZeroAttr(Type type)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateVectorExtractStridedSliceToExtractInsertChainPatterns(RewritePatternSet &patterns, std::function< bool(ExtractStridedSliceOp)> controlFn=nullptr, PatternBenefit benefit=1)
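Populate patterns with a pattern that breaks down 1-D extract_strided_slice ops into a chain of Extract ops (one per selected element of the source) followed by a chain of Insert ops that place those elements into the result vector.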
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...