MLIR  17.0.0git
VectorInsertExtractStridedSliceRewritePatterns.cpp
Go to the documentation of this file.
1 //===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/PatternMatch.h"
17 
18 using namespace mlir;
19 using namespace mlir::vector;
20 
21 // Helper that picks the proper sequence for inserting.
22 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
23  Value into, int64_t offset) {
24  auto vectorType = into.getType().cast<VectorType>();
25  if (vectorType.getRank() > 1)
26  return rewriter.create<InsertOp>(loc, from, into, offset);
27  return rewriter.create<vector::InsertElementOp>(
28  loc, vectorType, from, into,
29  rewriter.create<arith::ConstantIndexOp>(loc, offset));
30 }
31 
32 // Helper that picks the proper sequence for extracting.
33 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
34  int64_t offset) {
35  auto vectorType = vector.getType().cast<VectorType>();
36  if (vectorType.getRank() > 1)
37  return rewriter.create<ExtractOp>(loc, vector, offset);
38  return rewriter.create<vector::ExtractElementOp>(
39  loc, vectorType.getElementType(), vector,
40  rewriter.create<arith::ConstantIndexOp>(loc, offset));
41 }
42 
43 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
44 /// have different ranks.
45 ///
46 /// When ranks are different, InsertStridedSlice needs to extract a properly
47 /// ranked vector from the destination vector into which to insert. This pattern
48 /// only takes care of this extraction part and forwards the rest to
49 /// [ConvertSameRankInsertStridedSliceIntoShuffle].
50 ///
51 /// For a k-D source and n-D destination vector (k < n), we emit:
/// 1. ExtractOp to extract the (unique) k-D subvector of the destination into
/// which to insert the k-D source.
/// 2. k-D -> k-D InsertStridedSlice op (lowered by
/// [ConvertSameRankInsertStridedSliceIntoShuffle])
55 /// 3. InsertOp that is the reverse of 1.
57  : public OpRewritePattern<InsertStridedSliceOp> {
58 public:
60 
61  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
62  PatternRewriter &rewriter) const override {
63  auto srcType = op.getSourceVectorType();
64  auto dstType = op.getDestVectorType();
65 
66  if (op.getOffsets().getValue().empty())
67  return failure();
68 
69  auto loc = op.getLoc();
70  int64_t rankDiff = dstType.getRank() - srcType.getRank();
71  assert(rankDiff >= 0);
72  if (rankDiff == 0)
73  return failure();
74 
75  int64_t rankRest = dstType.getRank() - rankDiff;
76  // Extract / insert the subvector of matching rank and InsertStridedSlice
77  // on it.
78  Value extracted = rewriter.create<ExtractOp>(
79  loc, op.getDest(),
80  getI64SubArray(op.getOffsets(), /*dropFront=*/0,
81  /*dropBack=*/rankRest));
82 
83  // A different pattern will kick in for InsertStridedSlice with matching
84  // ranks.
85  auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
86  loc, op.getSource(), extracted,
87  getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
88  getI64SubArray(op.getStrides(), /*dropFront=*/0));
89 
90  rewriter.replaceOpWithNewOp<InsertOp>(
91  op, stridedSliceInnerOp.getResult(), op.getDest(),
92  getI64SubArray(op.getOffsets(), /*dropFront=*/0,
93  /*dropBack=*/rankRest));
94  return success();
95  }
96 };
97 
98 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
99 /// have the same rank. For each outermost index in the slice:
100 /// begin end stride
101 /// [offset : offset+size*stride : stride]
102 /// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
103 /// 2. InsertStridedSlice (k-1)-D into (n-1)-D
/// 3. InsertOp that writes the result back into the destination at the same
/// position, i.e. the reverse of 1.
/// 1-D sources are instead handled with two ShuffleOps.
107  : public OpRewritePattern<InsertStridedSliceOp> {
108 public:
110 
111  void initialize() {
112  // This pattern creates recursive InsertStridedSliceOp, but the recursion is
113  // bounded as the rank is strictly decreasing.
114  setHasBoundedRewriteRecursion();
115  }
116 
117  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
118  PatternRewriter &rewriter) const override {
119  auto srcType = op.getSourceVectorType();
120  auto dstType = op.getDestVectorType();
121 
122  if (op.getOffsets().getValue().empty())
123  return failure();
124 
125  int64_t srcRank = srcType.getRank();
126  int64_t dstRank = dstType.getRank();
127  assert(dstRank >= srcRank);
128  if (dstRank != srcRank)
129  return failure();
130 
131  if (srcType == dstType) {
132  rewriter.replaceOp(op, op.getSource());
133  return success();
134  }
135 
136  int64_t offset =
137  op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
138  int64_t size = srcType.getShape().front();
139  int64_t stride =
140  op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
141 
142  auto loc = op.getLoc();
143  Value res = op.getDest();
144 
145  if (srcRank == 1) {
146  int nSrc = srcType.getShape().front();
147  int nDest = dstType.getShape().front();
148  // 1. Scale source to destType so we can shufflevector them together.
149  SmallVector<int64_t> offsets(nDest, 0);
150  for (int64_t i = 0; i < nSrc; ++i)
151  offsets[i] = i;
152  Value scaledSource = rewriter.create<ShuffleOp>(loc, op.getSource(),
153  op.getSource(), offsets);
154 
155  // 2. Create a mask where we take the value from scaledSource of dest
156  // depending on the offset.
157  offsets.clear();
158  for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
159  if (i < offset || i >= e || (i - offset) % stride != 0)
160  offsets.push_back(nDest + i);
161  else
162  offsets.push_back((i - offset) / stride);
163  }
164 
165  // 3. Replace with a ShuffleOp.
166  rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.getDest(),
167  offsets);
168 
169  return success();
170  }
171 
172  // For each slice of the source vector along the most major dimension.
173  for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
174  off += stride, ++idx) {
175  // 1. extract the proper subvector (or element) from source
176  Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx);
177  if (extractedSource.getType().isa<VectorType>()) {
178  // 2. If we have a vector, extract the proper subvector from destination
179  // Otherwise we are at the element level and no need to recurse.
180  Value extractedDest = extractOne(rewriter, loc, op.getDest(), off);
181  // 3. Reduce the problem to lowering a new InsertStridedSlice op with
182  // smaller rank.
183  extractedSource = rewriter.create<InsertStridedSliceOp>(
184  loc, extractedSource, extractedDest,
185  getI64SubArray(op.getOffsets(), /* dropFront=*/1),
186  getI64SubArray(op.getStrides(), /* dropFront=*/1));
187  }
188  // 4. Insert the extractedSource into the res vector.
189  res = insertOne(rewriter, loc, extractedSource, res, off);
190  }
191 
192  rewriter.replaceOp(op, res);
193  return success();
194  }
195 };
196 
/// RewritePattern for ExtractStridedSliceOp with a single offset, i.e. only the
/// outermost dimension is sliced (e.g. 1-D source and destination vectors).
/// For such cases, we can lower it to a ShuffleOp.
200  : public OpRewritePattern<ExtractStridedSliceOp> {
201 public:
203 
204  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
205  PatternRewriter &rewriter) const override {
206  auto dstType = op.getType();
207 
208  assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
209 
210  int64_t offset =
211  op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
212  int64_t size =
213  op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
214  int64_t stride =
215  op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
216 
217  assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
218 
219  // Single offset can be more efficiently shuffled.
220  if (op.getOffsets().getValue().size() != 1)
221  return failure();
222 
223  SmallVector<int64_t, 4> offsets;
224  offsets.reserve(size);
225  for (int64_t off = offset, e = offset + size * stride; off < e;
226  off += stride)
227  offsets.push_back(off);
228  rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
229  op.getVector(),
230  rewriter.getI64ArrayAttr(offsets));
231  return success();
232  }
233 };
234 
/// For a 1-D ExtractStridedSlice, breaks it down into a chain of Extract ops
/// to extract each element from the source, and then a chain of Insert ops
/// to insert them into a zero-initialized result vector.
239  : public OpRewritePattern<ExtractStridedSliceOp> {
240 public:
242  MLIRContext *context,
243  std::function<bool(ExtractStridedSliceOp)> controlFn,
244  PatternBenefit benefit)
245  : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
246 
247  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
248  PatternRewriter &rewriter) const override {
249  if (controlFn && !controlFn(op))
250  return failure();
251 
252  // Only handle 1-D cases.
253  if (op.getOffsets().getValue().size() != 1)
254  return failure();
255 
256  int64_t offset =
257  op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
258  int64_t size =
259  op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
260  int64_t stride =
261  op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
262 
263  Location loc = op.getLoc();
264  SmallVector<Value> elements;
265  elements.reserve(size);
266  for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
267  elements.push_back(rewriter.create<ExtractOp>(loc, op.getVector(), i));
268 
269  Value result = rewriter.create<arith::ConstantOp>(
270  loc, rewriter.getZeroAttr(op.getType()));
271  for (int64_t i = 0; i < size; ++i)
272  result = rewriter.create<InsertOp>(loc, elements[i], result, i);
273 
274  rewriter.replaceOp(op, result);
275  return success();
276  }
277 
278 private:
279  std::function<bool(ExtractStridedSliceOp)> controlFn;
280 };
281 
282 /// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
283 /// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
284 /// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
286  : public OpRewritePattern<ExtractStridedSliceOp> {
287 public:
289 
290  void initialize() {
291  // This pattern creates recursive ExtractStridedSliceOp, but the recursion
292  // is bounded as the rank is strictly decreasing.
293  setHasBoundedRewriteRecursion();
294  }
295 
296  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
297  PatternRewriter &rewriter) const override {
298  auto dstType = op.getType();
299 
300  assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
301 
302  int64_t offset =
303  op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
304  int64_t size =
305  op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
306  int64_t stride =
307  op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
308 
309  auto loc = op.getLoc();
310  auto elemType = dstType.getElementType();
311  assert(elemType.isSignlessIntOrIndexOrFloat());
312 
313  // Single offset can be more efficiently shuffled. It's handled in
314  // Convert1DExtractStridedSliceIntoShuffle.
315  if (op.getOffsets().getValue().size() == 1)
316  return failure();
317 
318  // Extract/insert on a lower ranked extract strided slice op.
319  Value zero = rewriter.create<arith::ConstantOp>(
320  loc, elemType, rewriter.getZeroAttr(elemType));
321  Value res = rewriter.create<SplatOp>(loc, dstType, zero);
322  for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
323  off += stride, ++idx) {
324  Value one = extractOne(rewriter, loc, op.getVector(), off);
325  Value extracted = rewriter.create<ExtractStridedSliceOp>(
326  loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
327  getI64SubArray(op.getSizes(), /* dropFront=*/1),
328  getI64SubArray(op.getStrides(), /* dropFront=*/1));
329  res = insertOne(rewriter, loc, extracted, res, idx);
330  }
331  rewriter.replaceOp(op, res);
332  return success();
333  }
334 };
335 
337  RewritePatternSet &patterns, PatternBenefit benefit) {
339  DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
340 }
341 
343  RewritePatternSet &patterns,
344  std::function<bool(ExtractStridedSliceOp)> controlFn,
345  PatternBenefit benefit) {
347  patterns.getContext(), std::move(controlFn), benefit);
348 }
349 
350 /// Populate the given list with patterns that convert from Vector to LLVM.
352  RewritePatternSet &patterns, PatternBenefit benefit) {
354  benefit);
357  benefit);
358 }
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, int64_t offset)
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, Value into, int64_t offset)
For a 1-D ExtractStridedSlice, breaks it down into a chain of Extract ops to extract each element fro...
Convert1DExtractStridedSliceIntoExtractInsertChain(MLIRContext *context, std::function< bool(ExtractStridedSliceOp)> controlFn, PatternBenefit benefit)
LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for ExtractStridedSliceOp where source and destination vectors are 1-D.
LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for InsertStridedSliceOp where source and destination vectors have the same rank.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for InsertStridedSliceOp where source and destination vectors have different ranks.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override
Attribute getZeroAttr(Type type)
Definition: Builders.cpp:318
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:274
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:432
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:668
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:482
U cast() const
Definition: Types.h:321
bool isa() const
Definition: Types.h:301
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
void populateVectorExtractStridedSliceToExtractInsertChainPatterns(RewritePatternSet &patterns, std::function< bool(ExtractStridedSliceOp)> controlFn=nullptr, PatternBenefit benefit=1)
Populate patterns with a pattern to break down 1-D extract_strided_slice ops into a chain of Extract...
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357