MLIR  21.0.0git
VectorInsertExtractStridedSliceRewritePatterns.cpp
Go to the documentation of this file.
1 //===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/PatternMatch.h"
17 
18 using namespace mlir;
19 using namespace mlir::vector;
20 
21 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
22 /// have different ranks.
23 ///
24 /// When ranks are different, InsertStridedSlice needs to extract a properly
25 /// ranked vector from the destination vector into which to insert. This pattern
26 /// only takes care of this extraction part and forwards the rest to
27 /// [ConvertSameRankInsertStridedSliceIntoShuffle].
28 ///
29 /// For a k-D source and n-D destination vector (k < n), we emit:
30 /// 1. ExtractOp to extract the (unique) k-D subvector, selected by the
31 /// leading n-k offsets, into which to insert the k-D source.
32 /// 2. k-D -> k-D (same-rank) InsertStridedSlice op
33 /// 3. InsertOp that is the reverse of 1.
35  : public OpRewritePattern<InsertStridedSliceOp> {
36 public:
38 
39  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
40  PatternRewriter &rewriter) const override {
41  auto srcType = op.getSourceVectorType();
42  auto dstType = op.getDestVectorType();
43 
44  if (op.getOffsets().getValue().empty())
45  return failure();
46 
47  auto loc = op.getLoc();
48  int64_t rankDiff = dstType.getRank() - srcType.getRank();
49  assert(rankDiff >= 0);
50  if (rankDiff == 0)
51  return failure();
52 
53  int64_t rankRest = dstType.getRank() - rankDiff;
54  // Extract / insert the subvector of matching rank and InsertStridedSlice
55  // on it.
56  Value extracted = rewriter.create<ExtractOp>(
57  loc, op.getDest(),
58  getI64SubArray(op.getOffsets(), /*dropFront=*/0,
59  /*dropBack=*/rankRest));
60 
61  // A different pattern will kick in for InsertStridedSlice with matching
62  // ranks.
63  auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
64  loc, op.getSource(), extracted,
65  getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
66  getI64SubArray(op.getStrides(), /*dropFront=*/0));
67 
68  rewriter.replaceOpWithNewOp<InsertOp>(
69  op, stridedSliceInnerOp.getResult(), op.getDest(),
70  getI64SubArray(op.getOffsets(), /*dropFront=*/0,
71  /*dropBack=*/rankRest));
72  return success();
73  }
74 };
75 
76 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
77 /// have the same rank. For each outermost index in the slice:
78 /// begin end stride
79 /// [offset : offset+size*stride : stride]
80 /// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
81 /// 2. InsertStridedSlice (k-1)-D into (n-1)-D
82 /// 3. InsertOp (the reverse of 1.) puts the updated destination subvector
83 /// back in the proper place.
85  : public OpRewritePattern<InsertStridedSliceOp> {
86 public:
88 
89  void initialize() {
90  // This pattern creates recursive InsertStridedSliceOp, but the recursion is
91  // bounded as the rank is strictly decreasing.
92  setHasBoundedRewriteRecursion();
93  }
94 
95  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
96  PatternRewriter &rewriter) const override {
97  auto srcType = op.getSourceVectorType();
98  auto dstType = op.getDestVectorType();
99  int64_t srcRank = srcType.getRank();
100 
101  // Scalable vectors are not supported by vector shuffle.
102  if ((srcType.isScalable() || dstType.isScalable()) && srcRank == 1)
103  return failure();
104 
105  if (op.getOffsets().getValue().empty())
106  return failure();
107 
108  int64_t dstRank = dstType.getRank();
109  assert(dstRank >= srcRank);
110  if (dstRank != srcRank)
111  return failure();
112 
113  if (srcType == dstType) {
114  rewriter.replaceOp(op, op.getSource());
115  return success();
116  }
117 
118  int64_t offset =
119  cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
120  int64_t size = srcType.getShape().front();
121  int64_t stride =
122  cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
123 
124  auto loc = op.getLoc();
125  Value res = op.getDest();
126 
127  if (srcRank == 1) {
128  int nSrc = srcType.getShape().front();
129  int nDest = dstType.getShape().front();
130  // 1. Scale source to destType so we can shufflevector them together.
131  SmallVector<int64_t> offsets(nDest, 0);
132  for (int64_t i = 0; i < nSrc; ++i)
133  offsets[i] = i;
134  Value scaledSource = rewriter.create<ShuffleOp>(loc, op.getSource(),
135  op.getSource(), offsets);
136 
137  // 2. Create a mask where we take the value from scaledSource of dest
138  // depending on the offset.
139  offsets.clear();
140  for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
141  if (i < offset || i >= e || (i - offset) % stride != 0)
142  offsets.push_back(nDest + i);
143  else
144  offsets.push_back((i - offset) / stride);
145  }
146 
147  // 3. Replace with a ShuffleOp.
148  rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.getDest(),
149  offsets);
150 
151  return success();
152  }
153 
154  // For each slice of the source vector along the most major dimension.
155  for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
156  off += stride, ++idx) {
157  // 1. extract the proper subvector (or element) from source
158  Value extractedSource =
159  rewriter.create<ExtractOp>(loc, op.getSource(), idx);
160  if (isa<VectorType>(extractedSource.getType())) {
161  // 2. If we have a vector, extract the proper subvector from destination
162  // Otherwise we are at the element level and no need to recurse.
163  Value extractedDest =
164  rewriter.create<ExtractOp>(loc, op.getDest(), off);
165  // 3. Reduce the problem to lowering a new InsertStridedSlice op with
166  // smaller rank.
167  extractedSource = rewriter.create<InsertStridedSliceOp>(
168  loc, extractedSource, extractedDest,
169  getI64SubArray(op.getOffsets(), /* dropFront=*/1),
170  getI64SubArray(op.getStrides(), /* dropFront=*/1));
171  }
172  // 4. Insert the extractedSource into the res vector.
173  res = rewriter.create<InsertOp>(loc, extractedSource, res, off);
174  }
175 
176  rewriter.replaceOp(op, res);
177  return success();
178  }
179 };
180 
181 /// RewritePattern for ExtractStridedSliceOp that slices only the leading
182 /// dimension (a single offset), e.g. 1-D vectors. Such cases lower to a ShuffleOp.
184  : public OpRewritePattern<ExtractStridedSliceOp> {
185 public:
187 
188  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
189  PatternRewriter &rewriter) const override {
190  auto dstType = op.getType();
191  auto srcType = op.getSourceVectorType();
192 
193  // Scalable vectors are not supported by vector shuffle.
194  if (dstType.isScalable() || srcType.isScalable())
195  return failure();
196 
197  assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
198 
199  int64_t offset =
200  cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
201  int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
202  int64_t stride =
203  cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
204 
205  assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
206 
207  // Single offset can be more efficiently shuffled.
208  if (op.getOffsets().getValue().size() != 1)
209  return failure();
210 
211  SmallVector<int64_t, 4> offsets;
212  offsets.reserve(size);
213  for (int64_t off = offset, e = offset + size * stride; off < e;
214  off += stride)
215  offsets.push_back(off);
216  rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
217  op.getVector(), offsets);
218  return success();
219  }
220 };
221 
222 /// For a 1-D ExtractStridedSlice, breaks it down into a chain of Extract ops
223 /// that extract each selected element from the source, followed by a chain of
224 /// Insert ops that insert them into a zero-initialized result vector.
226  : public OpRewritePattern<ExtractStridedSliceOp> {
227 public:
229  MLIRContext *context,
230  std::function<bool(ExtractStridedSliceOp)> controlFn,
231  PatternBenefit benefit)
232  : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
233 
234  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
235  PatternRewriter &rewriter) const override {
236  if (controlFn && !controlFn(op))
237  return failure();
238 
239  // Only handle 1-D cases.
240  if (op.getOffsets().getValue().size() != 1)
241  return failure();
242 
243  int64_t offset =
244  cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
245  int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
246  int64_t stride =
247  cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
248 
249  Location loc = op.getLoc();
250  SmallVector<Value> elements;
251  elements.reserve(size);
252  for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
253  elements.push_back(rewriter.create<ExtractOp>(loc, op.getVector(), i));
254 
255  Value result = rewriter.create<arith::ConstantOp>(
256  loc, rewriter.getZeroAttr(op.getType()));
257  for (int64_t i = 0; i < size; ++i)
258  result = rewriter.create<InsertOp>(loc, elements[i], result, i);
259 
260  rewriter.replaceOp(op, result);
261  return success();
262  }
263 
264 private:
265  std::function<bool(ExtractStridedSliceOp)> controlFn;
266 };
267 
268 /// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
269 /// For such cases, we can rewrite it to ExtractOp + lower rank
270 /// ExtractStridedSliceOp + InsertOp for the n-D case.
272  : public OpRewritePattern<ExtractStridedSliceOp> {
273 public:
275 
276  void initialize() {
277  // This pattern creates recursive ExtractStridedSliceOp, but the recursion
278  // is bounded as the rank is strictly decreasing.
279  setHasBoundedRewriteRecursion();
280  }
281 
282  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
283  PatternRewriter &rewriter) const override {
284  auto dstType = op.getType();
285 
286  assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
287 
288  int64_t offset =
289  cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
290  int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
291  int64_t stride =
292  cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
293 
294  auto loc = op.getLoc();
295  auto elemType = dstType.getElementType();
296  assert(elemType.isSignlessIntOrIndexOrFloat());
297 
298  // Single offset can be more efficiently shuffled. It's handled in
299  // Convert1DExtractStridedSliceIntoShuffle.
300  if (op.getOffsets().getValue().size() == 1)
301  return failure();
302 
303  // Extract/insert on a lower ranked extract strided slice op.
304  Value zero = rewriter.create<arith::ConstantOp>(
305  loc, elemType, rewriter.getZeroAttr(elemType));
306  Value res = rewriter.create<SplatOp>(loc, dstType, zero);
307  for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
308  off += stride, ++idx) {
309  Value one = rewriter.create<ExtractOp>(loc, op.getVector(), off);
310  Value extracted = rewriter.create<ExtractStridedSliceOp>(
311  loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
312  getI64SubArray(op.getSizes(), /* dropFront=*/1),
313  getI64SubArray(op.getStrides(), /* dropFront=*/1));
314  res = rewriter.create<InsertOp>(loc, extracted, res, idx);
315  }
316  rewriter.replaceOp(op, res);
317  return success();
318  }
319 };
320 
321 // TODO: Make sure these `populate*` patterns are tested in isolation.
322 
326  DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
327 }
328 
331  std::function<bool(ExtractStridedSliceOp)> controlFn,
332  PatternBenefit benefit) {
334  patterns.getContext(), std::move(controlFn), benefit);
335 }
336 
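337 /// Populate the given list with patterns that lower InsertStridedSlice / ExtractStridedSlice ops into simpler Vector ops (vector shuffles, or extract/insert chains for scalable vectors).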
341  benefit);
344  benefit);
345  // Generate chains of extract/insert ops for scalable vectors only as they
346  // can't be lowered to vector shuffles.
348  patterns,
349  /*controlFn=*/
350  [](ExtractStridedSliceOp op) {
351  return op.getType().isScalable() ||
352  op.getSourceVectorType().isScalable();
353  },
354  benefit);
355 }
For a 1-D ExtractStridedSlice, breaks it down into a chain of Extract ops to extract each element fro...
Convert1DExtractStridedSliceIntoExtractInsertChain(MLIRContext *context, std::function< bool(ExtractStridedSliceOp)> controlFn, PatternBenefit benefit)
LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for ExtractStridedSliceOp where source and destination vectors are 1-D.
LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for InsertStridedSliceOp where source and destination vectors have the same rank.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for InsertStridedSliceOp where source and destination vectors have different ranks.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void populateVectorExtractStridedSliceToExtractInsertChainPatterns(RewritePatternSet &patterns, std::function< bool(ExtractStridedSliceOp)> controlFn=nullptr, PatternBenefit benefit=1)
Populate patterns with a pattern to break down 1-D extract_strided_slice ops into a chain of Extract...
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
Include the generated interface declarations.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358