23 Value into, int64_t offset) {
24 auto vectorType = cast<VectorType>(into.
getType());
25 if (vectorType.getRank() > 1)
26 return rewriter.
create<InsertOp>(loc, from, into, offset);
27 return rewriter.
create<vector::InsertElementOp>(
28 loc, vectorType, from, into,
29 rewriter.
create<arith::ConstantIndexOp>(loc, offset));
35 auto vectorType = cast<VectorType>(vector.
getType());
36 if (vectorType.getRank() > 1)
37 return rewriter.
create<ExtractOp>(loc, vector, offset);
38 return rewriter.
create<vector::ExtractElementOp>(
39 loc, vectorType.getElementType(), vector,
40 rewriter.
create<arith::ConstantIndexOp>(loc, offset));
63 auto srcType = op.getSourceVectorType();
64 auto dstType = op.getDestVectorType();
66 if (op.getOffsets().getValue().empty())
70 int64_t rankDiff = dstType.getRank() - srcType.getRank();
71 assert(rankDiff >= 0);
75 int64_t rankRest = dstType.getRank() - rankDiff;
85 auto stridedSliceInnerOp = rewriter.
create<InsertStridedSliceOp>(
86 loc, op.getSource(), extracted,
91 op, stridedSliceInnerOp.
getResult(), op.getDest(),
114 setHasBoundedRewriteRecursion();
119 auto srcType = op.getSourceVectorType();
120 auto dstType = op.getDestVectorType();
122 if (op.getOffsets().getValue().empty())
125 int64_t srcRank = srcType.getRank();
126 int64_t dstRank = dstType.getRank();
127 assert(dstRank >= srcRank);
128 if (dstRank != srcRank)
131 if (srcType == dstType) {
137 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
138 int64_t size = srcType.getShape().front();
140 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
143 Value res = op.getDest();
146 int nSrc = srcType.getShape().front();
147 int nDest = dstType.getShape().front();
150 for (int64_t i = 0; i < nSrc; ++i)
152 Value scaledSource = rewriter.
create<ShuffleOp>(loc, op.getSource(),
153 op.getSource(), offsets);
158 for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
159 if (i < offset || i >= e || (i - offset) % stride != 0)
160 offsets.push_back(nDest + i);
162 offsets.push_back((i - offset) / stride);
173 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
174 off += stride, ++idx) {
176 Value extractedSource =
extractOne(rewriter, loc, op.getSource(), idx);
177 if (isa<VectorType>(extractedSource.
getType())) {
183 extractedSource = rewriter.
create<InsertStridedSliceOp>(
184 loc, extractedSource, extractedDest,
189 res =
insertOne(rewriter, loc, extractedSource, res, off);
206 auto dstType = op.getType();
208 assert(!op.getOffsets().getValue().empty() &&
"Unexpected empty offsets");
211 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
212 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
214 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
216 assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
219 if (op.getOffsets().getValue().size() != 1)
223 offsets.reserve(size);
224 for (int64_t off = offset, e = offset + size * stride; off < e;
226 offsets.push_back(off);
242 std::function<
bool(ExtractStridedSliceOp)> controlFn,
248 if (controlFn && !controlFn(op))
252 if (op.getOffsets().getValue().size() != 1)
256 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
257 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
259 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
263 elements.reserve(size);
264 for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
265 elements.push_back(rewriter.
create<ExtractOp>(loc, op.getVector(), i));
269 for (int64_t i = 0; i < size; ++i)
270 result = rewriter.
create<InsertOp>(loc, elements[i], result, i);
277 std::function<bool(ExtractStridedSliceOp)> controlFn;
291 setHasBoundedRewriteRecursion();
296 auto dstType = op.getType();
298 assert(!op.getOffsets().getValue().empty() &&
"Unexpected empty offsets");
301 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
302 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
304 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
307 auto elemType = dstType.getElementType();
308 assert(elemType.isSignlessIntOrIndexOrFloat());
312 if (op.getOffsets().getValue().size() == 1)
318 Value res = rewriter.
create<SplatOp>(loc, dstType, zero);
319 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
320 off += stride, ++idx) {
322 Value extracted = rewriter.
create<ExtractStridedSliceOp>(
326 res =
insertOne(rewriter, loc, extracted, res, idx);
341 std::function<
bool(ExtractStridedSliceOp)> controlFn,
344 patterns.
getContext(), std::move(controlFn), benefit);
RewritePattern for InsertStridedSliceOp where source and destination vectors have the same rank.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for InsertStridedSliceOp where source and destination vectors have different ranks.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
TypedAttr getZeroAttr(Type type)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the results of the given op with the results of a new op of type OpTy, which is built from the given arguments without verification.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
void populateVectorExtractStridedSliceToExtractInsertChainPatterns(RewritePatternSet &patterns, std::function< bool(ExtractStridedSliceOp)> controlFn=nullptr, PatternBenefit benefit=1)
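Populate patterns with a pattern that breaks down 1-D extract_strided_slice ops into a chain of Extract ops, one for each selected element of the source vector, followed by a chain of Insert ops that write those elements into the result vector.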
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.