MLIR  16.0.0git
VectorInsertExtractStridedSliceRewritePatterns.cpp
Go to the documentation of this file.
1 //===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
15 #include "mlir/IR/BuiltinTypes.h"
16 
17 using namespace mlir;
18 using namespace mlir::vector;
19 
20 // Helper that picks the proper sequence for inserting.
21 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
22  Value into, int64_t offset) {
23  auto vectorType = into.getType().cast<VectorType>();
24  if (vectorType.getRank() > 1)
25  return rewriter.create<InsertOp>(loc, from, into, offset);
26  return rewriter.create<vector::InsertElementOp>(
27  loc, vectorType, from, into,
28  rewriter.create<arith::ConstantIndexOp>(loc, offset));
29 }
30 
31 // Helper that picks the proper sequence for extracting.
32 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
33  int64_t offset) {
34  auto vectorType = vector.getType().cast<VectorType>();
35  if (vectorType.getRank() > 1)
36  return rewriter.create<ExtractOp>(loc, vector, offset);
37  return rewriter.create<vector::ExtractElementOp>(
38  loc, vectorType.getElementType(), vector,
39  rewriter.create<arith::ConstantIndexOp>(loc, offset));
40 }
41 
42 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
43 /// have different ranks.
44 ///
45 /// When ranks are different, InsertStridedSlice needs to extract a properly
46 /// ranked vector from the destination vector into which to insert. This pattern
47 /// only takes care of this extraction part and forwards the rest to
48 /// [ConvertSameRankInsertStridedSliceIntoShuffle].
49 ///
50 /// For a k-D source and n-D destination vector (k < n), we emit:
51 /// 1. ExtractOp to extract the (unique) k-D subvector, addressed by the
52 /// leading (n-k) offsets, into which to insert the k-D source.
53 /// 2. k-D -> k-D InsertStridedSlice op
54 /// 3. InsertOp that is the reverse of 1.
56  : public OpRewritePattern<InsertStridedSliceOp> {
57 public:
59 
60  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
61  PatternRewriter &rewriter) const override {
62  auto srcType = op.getSourceVectorType();
63  auto dstType = op.getDestVectorType();
64 
65  if (op.getOffsets().getValue().empty())
66  return failure();
67 
68  auto loc = op.getLoc();
69  int64_t rankDiff = dstType.getRank() - srcType.getRank();
70  assert(rankDiff >= 0);
71  if (rankDiff == 0)
72  return failure();
73 
74  int64_t rankRest = dstType.getRank() - rankDiff;
75  // Extract / insert the subvector of matching rank and InsertStridedSlice
76  // on it.
77  Value extracted = rewriter.create<ExtractOp>(
78  loc, op.getDest(),
79  getI64SubArray(op.getOffsets(), /*dropFront=*/0,
80  /*dropBack=*/rankRest));
81 
82  // A different pattern will kick in for InsertStridedSlice with matching
83  // ranks.
84  auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
85  loc, op.getSource(), extracted,
86  getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
87  getI64SubArray(op.getStrides(), /*dropFront=*/0));
88 
89  rewriter.replaceOpWithNewOp<InsertOp>(
90  op, stridedSliceInnerOp.getResult(), op.getDest(),
91  getI64SubArray(op.getOffsets(), /*dropFront=*/0,
92  /*dropBack=*/rankRest));
93  return success();
94  }
95 };
96 
97 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
98 /// have the same rank. For each outermost index in the slice:
99 /// begin end stride
100 /// [offset : offset+size*stride : stride]
101 /// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
102 /// 2. InsertStridedSlice (k-1)-D into (n-1)-D
103 /// 3. InsertOp that inserts the resulting (n-1)-D subvector back into the
104 /// proper place of the destination (the reverse of 1.).
106  : public OpRewritePattern<InsertStridedSliceOp> {
107 public:
109 
110  void initialize() {
111  // This pattern creates recursive InsertStridedSliceOp, but the recursion is
112  // bounded as the rank is strictly decreasing.
113  setHasBoundedRewriteRecursion();
114  }
115 
116  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
117  PatternRewriter &rewriter) const override {
118  auto srcType = op.getSourceVectorType();
119  auto dstType = op.getDestVectorType();
120 
121  if (op.getOffsets().getValue().empty())
122  return failure();
123 
124  int64_t srcRank = srcType.getRank();
125  int64_t dstRank = dstType.getRank();
126  assert(dstRank >= srcRank);
127  if (dstRank != srcRank)
128  return failure();
129 
130  if (srcType == dstType) {
131  rewriter.replaceOp(op, op.getSource());
132  return success();
133  }
134 
135  int64_t offset =
136  op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
137  int64_t size = srcType.getShape().front();
138  int64_t stride =
139  op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
140 
141  auto loc = op.getLoc();
142  Value res = op.getDest();
143 
144  if (srcRank == 1) {
145  int nSrc = srcType.getShape().front();
146  int nDest = dstType.getShape().front();
147  // 1. Scale source to destType so we can shufflevector them together.
148  SmallVector<int64_t> offsets(nDest, 0);
149  for (int64_t i = 0; i < nSrc; ++i)
150  offsets[i] = i;
151  Value scaledSource = rewriter.create<ShuffleOp>(loc, op.getSource(),
152  op.getSource(), offsets);
153 
154  // 2. Create a mask where we take the value from scaledSource of dest
155  // depending on the offset.
156  offsets.clear();
157  for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
158  if (i < offset || i >= e || (i - offset) % stride != 0)
159  offsets.push_back(nDest + i);
160  else
161  offsets.push_back((i - offset) / stride);
162  }
163 
164  // 3. Replace with a ShuffleOp.
165  rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.getDest(),
166  offsets);
167 
168  return success();
169  }
170 
171  // For each slice of the source vector along the most major dimension.
172  for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
173  off += stride, ++idx) {
174  // 1. extract the proper subvector (or element) from source
175  Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx);
176  if (extractedSource.getType().isa<VectorType>()) {
177  // 2. If we have a vector, extract the proper subvector from destination
178  // Otherwise we are at the element level and no need to recurse.
179  Value extractedDest = extractOne(rewriter, loc, op.getDest(), off);
180  // 3. Reduce the problem to lowering a new InsertStridedSlice op with
181  // smaller rank.
182  extractedSource = rewriter.create<InsertStridedSliceOp>(
183  loc, extractedSource, extractedDest,
184  getI64SubArray(op.getOffsets(), /* dropFront=*/1),
185  getI64SubArray(op.getStrides(), /* dropFront=*/1));
186  }
187  // 4. Insert the extractedSource into the res vector.
188  res = insertOne(rewriter, loc, extractedSource, res, off);
189  }
190 
191  rewriter.replaceOp(op, res);
192  return success();
193  }
194 };
195 
196 /// RewritePattern for ExtractStridedSliceOp with a single offset, i.e. only
197 /// the leading dimension is sliced (always the case for 1-D vectors). For such cases, we can lower it to a ShuffleOp.
199  : public OpRewritePattern<ExtractStridedSliceOp> {
200 public:
202 
203  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
204  PatternRewriter &rewriter) const override {
205  auto dstType = op.getType();
206 
207  assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
208 
209  int64_t offset =
210  op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
211  int64_t size =
212  op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
213  int64_t stride =
214  op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
215 
216  assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
217 
218  // Single offset can be more efficiently shuffled.
219  if (op.getOffsets().getValue().size() != 1)
220  return failure();
221 
222  SmallVector<int64_t, 4> offsets;
223  offsets.reserve(size);
224  for (int64_t off = offset, e = offset + size * stride; off < e;
225  off += stride)
226  offsets.push_back(off);
227  rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
228  op.getVector(),
229  rewriter.getI64ArrayAttr(offsets));
230  return success();
231  }
232 };
233 
234 /// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
235 /// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
236 /// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
238  : public OpRewritePattern<ExtractStridedSliceOp> {
239 public:
241 
242  void initialize() {
243  // This pattern creates recursive ExtractStridedSliceOp, but the recursion
244  // is bounded as the rank is strictly decreasing.
245  setHasBoundedRewriteRecursion();
246  }
247 
248  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
249  PatternRewriter &rewriter) const override {
250  auto dstType = op.getType();
251 
252  assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
253 
254  int64_t offset =
255  op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
256  int64_t size =
257  op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
258  int64_t stride =
259  op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
260 
261  auto loc = op.getLoc();
262  auto elemType = dstType.getElementType();
263  assert(elemType.isSignlessIntOrIndexOrFloat());
264 
265  // Single offset can be more efficiently shuffled. It's handled in
266  // Convert1DExtractStridedSliceIntoShuffle.
267  if (op.getOffsets().getValue().size() == 1)
268  return failure();
269 
270  // Extract/insert on a lower ranked extract strided slice op.
271  Value zero = rewriter.create<arith::ConstantOp>(
272  loc, elemType, rewriter.getZeroAttr(elemType));
273  Value res = rewriter.create<SplatOp>(loc, dstType, zero);
274  for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
275  off += stride, ++idx) {
276  Value one = extractOne(rewriter, loc, op.getVector(), off);
277  Value extracted = rewriter.create<ExtractStridedSliceOp>(
278  loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
279  getI64SubArray(op.getSizes(), /* dropFront=*/1),
280  getI64SubArray(op.getStrides(), /* dropFront=*/1));
281  res = insertOne(rewriter, loc, extracted, res, idx);
282  }
283  rewriter.replaceOp(op, res);
284  return success();
285  }
286 };
287 
289  RewritePatternSet &patterns) {
292 }
293 
294 /// Populate the given list with the rewrite patterns for vector InsertStridedSliceOp and ExtractStridedSliceOp (these produce vector-dialect ops, not LLVM ops).
296  RewritePatternSet &patterns) {
300 }
Include the generated interface declarations.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
Attribute getZeroAttr(Type type)
Definition: Builders.cpp:288
SmallVector< int64_t, 4 > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper that returns a subset of arrayAttr as a vector of int64_t.
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:244
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns)
Populate patterns with the following patterns.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, Value into, int64_t offset)
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns)
Populate patterns with the following patterns.
LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, int64_t offset)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
Type getType() const
Return the type of this value.
Definition: Value.h:118
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
RewritePattern for InsertStridedSliceOp where source and destination vectors have different ranks...
Specialization of arith.constant op that returns an integer of index type.
Definition: Arithmetic.h:80
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for InsertStridedSliceOp where source and destination vectors have the same rank...
bool isa() const
Definition: Types.h:254
RewritePattern for ExtractStridedSliceOp where source and destination vectors are 1-D...
MLIRContext * getContext() const
U cast() const
Definition: Types.h:278