22 Value into, int64_t offset) {
25 return rewriter.
create<InsertOp>(loc, from, into, offset);
26 return rewriter.
create<vector::InsertElementOp>(
36 return rewriter.
create<ExtractOp>(loc, vector, offset);
37 return rewriter.
create<vector::ExtractElementOp>(
62 auto srcType = op.getSourceVectorType();
63 auto dstType = op.getDestVectorType();
65 if (op.getOffsets().getValue().empty())
68 auto loc = op.getLoc();
69 int64_t rankDiff = dstType.getRank() - srcType.getRank();
70 assert(rankDiff >= 0);
74 int64_t rankRest = dstType.getRank() - rankDiff;
84 auto stridedSliceInnerOp = rewriter.
create<InsertStridedSliceOp>(
85 loc, op.getSource(), extracted,
90 op, stridedSliceInnerOp.getResult(), op.getDest(),
113 setHasBoundedRewriteRecursion();
118 auto srcType = op.getSourceVectorType();
119 auto dstType = op.getDestVectorType();
121 if (op.getOffsets().getValue().empty())
124 int64_t srcRank = srcType.getRank();
125 int64_t dstRank = dstType.getRank();
126 assert(dstRank >= srcRank);
127 if (dstRank != srcRank)
130 if (srcType == dstType) {
136 op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
137 int64_t size = srcType.getShape().front();
139 op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
141 auto loc = op.getLoc();
142 Value res = op.getDest();
145 int nSrc = srcType.getShape().front();
146 int nDest = dstType.getShape().front();
149 for (int64_t i = 0; i < nSrc; ++i)
151 Value scaledSource = rewriter.
create<ShuffleOp>(loc, op.getSource(),
152 op.getSource(), offsets);
157 for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
158 if (i < offset || i >= e || (i - offset) % stride != 0)
159 offsets.push_back(nDest + i);
161 offsets.push_back((i - offset) / stride);
172 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
173 off += stride, ++idx) {
175 Value extractedSource =
extractOne(rewriter, loc, op.getSource(), idx);
176 if (extractedSource.
getType().
isa<VectorType>()) {
182 extractedSource = rewriter.
create<InsertStridedSliceOp>(
183 loc, extractedSource, extractedDest,
188 res =
insertOne(rewriter, loc, extractedSource, res, off);
205 auto dstType = op.getType();
207 assert(!op.getOffsets().getValue().empty() &&
"Unexpected empty offsets");
210 op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
212 op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
214 op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
216 assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
219 if (op.getOffsets().getValue().size() != 1)
223 offsets.reserve(size);
224 for (int64_t off = offset, e = offset + size * stride; off < e;
226 offsets.push_back(off);
245 setHasBoundedRewriteRecursion();
250 auto dstType = op.getType();
252 assert(!op.getOffsets().getValue().empty() &&
"Unexpected empty offsets");
255 op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
257 op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
259 op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
261 auto loc = op.getLoc();
262 auto elemType = dstType.getElementType();
263 assert(elemType.isSignlessIntOrIndexOrFloat());
267 if (op.getOffsets().getValue().size() == 1)
273 Value res = rewriter.
create<SplatOp>(loc, dstType, zero);
274 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
275 off += stride, ++idx) {
277 Value extracted = rewriter.
create<ExtractStridedSliceOp>(
281 res =
insertOne(rewriter, loc, extracted, res, idx);
Include the generated interface declarations.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Attribute getZeroAttr(Type type)
SmallVector< int64_t, 4 > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper that returns a subset of arrayAttr as a vector of int64_t.
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an efficient way to signal success or failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns)
Populate patterns with the following patterns.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns)
Populate patterns with the following patterns.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Type getType() const
Return the type of this value.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
RewritePattern for InsertStridedSliceOp where source and destination vectors have different ranks...
Specialization of arith.constant op that returns an integer of index type.
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for InsertStridedSliceOp where source and destination vectors have the same rank...
MLIRContext * getContext() const