MLIR  19.0.0git
VectorInsertExtractStridedSliceRewritePatterns.cpp
Go to the documentation of this file.
1 //===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/PatternMatch.h"
17 
18 using namespace mlir;
19 using namespace mlir::vector;
20 
21 // Helper that picks the proper sequence for inserting.
22 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
23  Value into, int64_t offset) {
24  auto vectorType = cast<VectorType>(into.getType());
25  if (vectorType.getRank() > 1)
26  return rewriter.create<InsertOp>(loc, from, into, offset);
27  return rewriter.create<vector::InsertElementOp>(
28  loc, vectorType, from, into,
29  rewriter.create<arith::ConstantIndexOp>(loc, offset));
30 }
31 
32 // Helper that picks the proper sequence for extracting.
33 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
34  int64_t offset) {
35  auto vectorType = cast<VectorType>(vector.getType());
36  if (vectorType.getRank() > 1)
37  return rewriter.create<ExtractOp>(loc, vector, offset);
38  return rewriter.create<vector::ExtractElementOp>(
39  loc, vectorType.getElementType(), vector,
40  rewriter.create<arith::ConstantIndexOp>(loc, offset));
41 }
42 
43 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
44 /// have different ranks.
45 ///
46 /// When ranks are different, InsertStridedSlice needs to extract a properly
47 /// ranked vector from the destination vector into which to insert. This pattern
48 /// only takes care of this extraction part and forwards the rest to
49 /// [ConvertSameRankInsertStridedSliceIntoShuffle].
50 ///
51 /// For a k-D source and n-D destination vector (k < n), we emit:
52 /// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
53 /// insert the k-D source.
54 /// 2. k-D -> (n-1)-D InsertStridedSlice op
55 /// 3. InsertOp that is the reverse of 1.
57  : public OpRewritePattern<InsertStridedSliceOp> {
58 public:
60 
61  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
62  PatternRewriter &rewriter) const override {
63  auto srcType = op.getSourceVectorType();
64  auto dstType = op.getDestVectorType();
65 
66  if (op.getOffsets().getValue().empty())
67  return failure();
68 
69  auto loc = op.getLoc();
70  int64_t rankDiff = dstType.getRank() - srcType.getRank();
71  assert(rankDiff >= 0);
72  if (rankDiff == 0)
73  return failure();
74 
75  int64_t rankRest = dstType.getRank() - rankDiff;
76  // Extract / insert the subvector of matching rank and InsertStridedSlice
77  // on it.
78  Value extracted = rewriter.create<ExtractOp>(
79  loc, op.getDest(),
80  getI64SubArray(op.getOffsets(), /*dropFront=*/0,
81  /*dropBack=*/rankRest));
82 
83  // A different pattern will kick in for InsertStridedSlice with matching
84  // ranks.
85  auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
86  loc, op.getSource(), extracted,
87  getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
88  getI64SubArray(op.getStrides(), /*dropFront=*/0));
89 
90  rewriter.replaceOpWithNewOp<InsertOp>(
91  op, stridedSliceInnerOp.getResult(), op.getDest(),
92  getI64SubArray(op.getOffsets(), /*dropFront=*/0,
93  /*dropBack=*/rankRest));
94  return success();
95  }
96 };
97 
98 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
99 /// have the same rank. For each outermost index in the slice:
100 /// begin end stride
101 /// [offset : offset+size*stride : stride]
102 /// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
103 /// 2. InsertStridedSlice (k-1)-D into (n-1)-D
104 /// 3. The updated destination subvector is inserted back in the proper place
105 ///    with an InsertOp that is the reverse of 1.
107  : public OpRewritePattern<InsertStridedSliceOp> {
108 public:
110 
111  void initialize() {
112  // This pattern creates recursive InsertStridedSliceOp, but the recursion is
113  // bounded as the rank is strictly decreasing.
114  setHasBoundedRewriteRecursion();
115  }
116 
117  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
118  PatternRewriter &rewriter) const override {
119  auto srcType = op.getSourceVectorType();
120  auto dstType = op.getDestVectorType();
121 
122  if (op.getOffsets().getValue().empty())
123  return failure();
124 
125  int64_t srcRank = srcType.getRank();
126  int64_t dstRank = dstType.getRank();
127  assert(dstRank >= srcRank);
128  if (dstRank != srcRank)
129  return failure();
130 
131  if (srcType == dstType) {
132  rewriter.replaceOp(op, op.getSource());
133  return success();
134  }
135 
136  int64_t offset =
137  cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
138  int64_t size = srcType.getShape().front();
139  int64_t stride =
140  cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
141 
142  auto loc = op.getLoc();
143  Value res = op.getDest();
144 
145  if (srcRank == 1) {
146  int nSrc = srcType.getShape().front();
147  int nDest = dstType.getShape().front();
148  // 1. Scale source to destType so we can shufflevector them together.
149  SmallVector<int64_t> offsets(nDest, 0);
150  for (int64_t i = 0; i < nSrc; ++i)
151  offsets[i] = i;
152  Value scaledSource = rewriter.create<ShuffleOp>(loc, op.getSource(),
153  op.getSource(), offsets);
154 
155  // 2. Create a mask where we take the value from scaledSource of dest
156  // depending on the offset.
157  offsets.clear();
158  for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
159  if (i < offset || i >= e || (i - offset) % stride != 0)
160  offsets.push_back(nDest + i);
161  else
162  offsets.push_back((i - offset) / stride);
163  }
164 
165  // 3. Replace with a ShuffleOp.
166  rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.getDest(),
167  offsets);
168 
169  return success();
170  }
171 
172  // For each slice of the source vector along the most major dimension.
173  for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
174  off += stride, ++idx) {
175  // 1. extract the proper subvector (or element) from source
176  Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx);
177  if (isa<VectorType>(extractedSource.getType())) {
178  // 2. If we have a vector, extract the proper subvector from destination
179  // Otherwise we are at the element level and no need to recurse.
180  Value extractedDest = extractOne(rewriter, loc, op.getDest(), off);
181  // 3. Reduce the problem to lowering a new InsertStridedSlice op with
182  // smaller rank.
183  extractedSource = rewriter.create<InsertStridedSliceOp>(
184  loc, extractedSource, extractedDest,
185  getI64SubArray(op.getOffsets(), /* dropFront=*/1),
186  getI64SubArray(op.getStrides(), /* dropFront=*/1));
187  }
188  // 4. Insert the extractedSource into the res vector.
189  res = insertOne(rewriter, loc, extractedSource, res, off);
190  }
191 
192  rewriter.replaceOp(op, res);
193  return success();
194  }
195 };
196 
197 /// RewritePattern for ExtractStridedSliceOp where source and destination
198 /// vectors are 1-D. For such cases, we can lower it to a ShuffleOp.
200  : public OpRewritePattern<ExtractStridedSliceOp> {
201 public:
203 
204  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
205  PatternRewriter &rewriter) const override {
206  auto dstType = op.getType();
207 
208  assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
209 
210  int64_t offset =
211  cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
212  int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
213  int64_t stride =
214  cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
215 
216  assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
217 
218  // Single offset can be more efficiently shuffled.
219  if (op.getOffsets().getValue().size() != 1)
220  return failure();
221 
222  SmallVector<int64_t, 4> offsets;
223  offsets.reserve(size);
224  for (int64_t off = offset, e = offset + size * stride; off < e;
225  off += stride)
226  offsets.push_back(off);
227  rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
228  op.getVector(),
229  rewriter.getI64ArrayAttr(offsets));
230  return success();
231  }
232 };
233 
234 /// For a 1-D ExtractStridedSlice, breaks it down into a chain of Extract ops
235 /// to extract each element from the source, and then a chain of Insert ops
236 /// to insert to the target vector.
238  : public OpRewritePattern<ExtractStridedSliceOp> {
239 public:
241  MLIRContext *context,
242  std::function<bool(ExtractStridedSliceOp)> controlFn,
243  PatternBenefit benefit)
244  : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
245 
246  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
247  PatternRewriter &rewriter) const override {
248  if (controlFn && !controlFn(op))
249  return failure();
250 
251  // Only handle 1-D cases.
252  if (op.getOffsets().getValue().size() != 1)
253  return failure();
254 
255  int64_t offset =
256  cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
257  int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
258  int64_t stride =
259  cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
260 
261  Location loc = op.getLoc();
262  SmallVector<Value> elements;
263  elements.reserve(size);
264  for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
265  elements.push_back(rewriter.create<ExtractOp>(loc, op.getVector(), i));
266 
267  Value result = rewriter.create<arith::ConstantOp>(
268  loc, rewriter.getZeroAttr(op.getType()));
269  for (int64_t i = 0; i < size; ++i)
270  result = rewriter.create<InsertOp>(loc, elements[i], result, i);
271 
272  rewriter.replaceOp(op, result);
273  return success();
274  }
275 
276 private:
277  std::function<bool(ExtractStridedSliceOp)> controlFn;
278 };
279 
280 /// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
281 /// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
282 /// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
284  : public OpRewritePattern<ExtractStridedSliceOp> {
285 public:
287 
288  void initialize() {
289  // This pattern creates recursive ExtractStridedSliceOp, but the recursion
290  // is bounded as the rank is strictly decreasing.
291  setHasBoundedRewriteRecursion();
292  }
293 
294  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
295  PatternRewriter &rewriter) const override {
296  auto dstType = op.getType();
297 
298  assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
299 
300  int64_t offset =
301  cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
302  int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
303  int64_t stride =
304  cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
305 
306  auto loc = op.getLoc();
307  auto elemType = dstType.getElementType();
308  assert(elemType.isSignlessIntOrIndexOrFloat());
309 
310  // Single offset can be more efficiently shuffled. It's handled in
311  // Convert1DExtractStridedSliceIntoShuffle.
312  if (op.getOffsets().getValue().size() == 1)
313  return failure();
314 
315  // Extract/insert on a lower ranked extract strided slice op.
316  Value zero = rewriter.create<arith::ConstantOp>(
317  loc, elemType, rewriter.getZeroAttr(elemType));
318  Value res = rewriter.create<SplatOp>(loc, dstType, zero);
319  for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
320  off += stride, ++idx) {
321  Value one = extractOne(rewriter, loc, op.getVector(), off);
322  Value extracted = rewriter.create<ExtractStridedSliceOp>(
323  loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
324  getI64SubArray(op.getSizes(), /* dropFront=*/1),
325  getI64SubArray(op.getStrides(), /* dropFront=*/1));
326  res = insertOne(rewriter, loc, extracted, res, idx);
327  }
328  rewriter.replaceOp(op, res);
329  return success();
330  }
331 };
332 
334  RewritePatternSet &patterns, PatternBenefit benefit) {
336  DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
337 }
338 
340  RewritePatternSet &patterns,
341  std::function<bool(ExtractStridedSliceOp)> controlFn,
342  PatternBenefit benefit) {
344  patterns.getContext(), std::move(controlFn), benefit);
345 }
346 
347 /// Populate the given list with patterns that convert from Vector to LLVM.
349  RewritePatternSet &patterns, PatternBenefit benefit) {
351  benefit);
354  benefit);
355 }
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, int64_t offset)
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, Value into, int64_t offset)
For a 1-D ExtractStridedSlice, breaks it down into a chain of Extract ops to extract each element fro...
Convert1DExtractStridedSliceIntoExtractInsertChain(MLIRContext *context, std::function< bool(ExtractStridedSliceOp)> controlFn, PatternBenefit benefit)
LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for ExtractStridedSliceOp where source and destination vectors are 1-D.
LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for InsertStridedSliceOp where source and destination vectors have the same rank.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for InsertStridedSliceOp where source and destination vectors have different ranks.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:288
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
void populateVectorExtractStridedSliceToExtractInsertChainPatterns(RewritePatternSet &patterns, std::function< bool(ExtractStridedSliceOp)> controlFn=nullptr, PatternBenefit benefit=1)
Populate patterns with a pattern to break down 1-D extract_strided_slice ops into a chain of Extract...
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358