MLIR  20.0.0git
VectorInsertExtractStridedSliceRewritePatterns.cpp
Go to the documentation of this file.
1 //===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/PatternMatch.h"
17 
18 using namespace mlir;
19 using namespace mlir::vector;
20 
21 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
22 /// have different ranks.
23 ///
24 /// When ranks are different, InsertStridedSlice needs to extract a properly
25 /// ranked vector from the destination vector into which to insert. This pattern
26 /// only takes care of this extraction part and forwards the rest to
27 /// [ConvertSameRankInsertStridedSliceIntoShuffle].
28 ///
29 /// For a k-D source and n-D destination vector (k < n), we emit:
30 /// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
31 /// insert the k-D source.
32 /// 2. k-D -> (n-1)-D InsertStridedSlice op
33 /// 3. InsertOp that is the reverse of 1.
35  : public OpRewritePattern<InsertStridedSliceOp> {
36 public:
38 
39  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
40  PatternRewriter &rewriter) const override {
41  auto srcType = op.getSourceVectorType();
42  auto dstType = op.getDestVectorType();
43 
44  if (op.getOffsets().getValue().empty())
45  return failure();
46 
47  auto loc = op.getLoc();
48  int64_t rankDiff = dstType.getRank() - srcType.getRank();
49  assert(rankDiff >= 0);
50  if (rankDiff == 0)
51  return failure();
52 
53  int64_t rankRest = dstType.getRank() - rankDiff;
54  // Extract / insert the subvector of matching rank and InsertStridedSlice
55  // on it.
56  Value extracted = rewriter.create<ExtractOp>(
57  loc, op.getDest(),
58  getI64SubArray(op.getOffsets(), /*dropFront=*/0,
59  /*dropBack=*/rankRest));
60 
61  // A different pattern will kick in for InsertStridedSlice with matching
62  // ranks.
63  auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
64  loc, op.getSource(), extracted,
65  getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
66  getI64SubArray(op.getStrides(), /*dropFront=*/0));
67 
68  rewriter.replaceOpWithNewOp<InsertOp>(
69  op, stridedSliceInnerOp.getResult(), op.getDest(),
70  getI64SubArray(op.getOffsets(), /*dropFront=*/0,
71  /*dropBack=*/rankRest));
72  return success();
73  }
74 };
75 
76 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
77 /// have the same rank. For each outermost index in the slice:
78 /// begin end stride
79 /// [offset : offset+size*stride : stride]
80 /// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
81 /// 2. InsertStridedSlice (k-1)-D into (n-1)-D
82 /// 3. the updated destination subvector is inserted back in the proper place
83 ///    with an InsertOp (the reverse of 1).
85  : public OpRewritePattern<InsertStridedSliceOp> {
86 public:
88 
89  void initialize() {
90  // This pattern creates recursive InsertStridedSliceOp, but the recursion is
91  // bounded as the rank is strictly decreasing.
92  setHasBoundedRewriteRecursion();
93  }
94 
95  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
96  PatternRewriter &rewriter) const override {
97  auto srcType = op.getSourceVectorType();
98  auto dstType = op.getDestVectorType();
99 
100  if (op.getOffsets().getValue().empty())
101  return failure();
102 
103  int64_t srcRank = srcType.getRank();
104  int64_t dstRank = dstType.getRank();
105  assert(dstRank >= srcRank);
106  if (dstRank != srcRank)
107  return failure();
108 
109  if (srcType == dstType) {
110  rewriter.replaceOp(op, op.getSource());
111  return success();
112  }
113 
114  int64_t offset =
115  cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
116  int64_t size = srcType.getShape().front();
117  int64_t stride =
118  cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
119 
120  auto loc = op.getLoc();
121  Value res = op.getDest();
122 
123  if (srcRank == 1) {
124  int nSrc = srcType.getShape().front();
125  int nDest = dstType.getShape().front();
126  // 1. Scale source to destType so we can shufflevector them together.
127  SmallVector<int64_t> offsets(nDest, 0);
128  for (int64_t i = 0; i < nSrc; ++i)
129  offsets[i] = i;
130  Value scaledSource = rewriter.create<ShuffleOp>(loc, op.getSource(),
131  op.getSource(), offsets);
132 
133  // 2. Create a mask where we take the value from scaledSource of dest
134  // depending on the offset.
135  offsets.clear();
136  for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
137  if (i < offset || i >= e || (i - offset) % stride != 0)
138  offsets.push_back(nDest + i);
139  else
140  offsets.push_back((i - offset) / stride);
141  }
142 
143  // 3. Replace with a ShuffleOp.
144  rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.getDest(),
145  offsets);
146 
147  return success();
148  }
149 
150  // For each slice of the source vector along the most major dimension.
151  for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
152  off += stride, ++idx) {
153  // 1. extract the proper subvector (or element) from source
154  Value extractedSource =
155  rewriter.create<ExtractOp>(loc, op.getSource(), idx);
156  if (isa<VectorType>(extractedSource.getType())) {
157  // 2. If we have a vector, extract the proper subvector from destination
158  // Otherwise we are at the element level and no need to recurse.
159  Value extractedDest =
160  rewriter.create<ExtractOp>(loc, op.getDest(), off);
161  // 3. Reduce the problem to lowering a new InsertStridedSlice op with
162  // smaller rank.
163  extractedSource = rewriter.create<InsertStridedSliceOp>(
164  loc, extractedSource, extractedDest,
165  getI64SubArray(op.getOffsets(), /* dropFront=*/1),
166  getI64SubArray(op.getStrides(), /* dropFront=*/1));
167  }
168  // 4. Insert the extractedSource into the res vector.
169  res = rewriter.create<InsertOp>(loc, extractedSource, res, off);
170  }
171 
172  rewriter.replaceOp(op, res);
173  return success();
174  }
175 };
176 
177 /// RewritePattern for ExtractStridedSliceOp with a single offset (e.g. source
178 /// and destination vectors are 1-D). For such cases, we can lower it to a ShuffleOp.
180  : public OpRewritePattern<ExtractStridedSliceOp> {
181 public:
183 
184  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
185  PatternRewriter &rewriter) const override {
186  auto dstType = op.getType();
187 
188  assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
189 
190  int64_t offset =
191  cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
192  int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
193  int64_t stride =
194  cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
195 
196  assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
197 
198  // Single offset can be more efficiently shuffled.
199  if (op.getOffsets().getValue().size() != 1)
200  return failure();
201 
202  SmallVector<int64_t, 4> offsets;
203  offsets.reserve(size);
204  for (int64_t off = offset, e = offset + size * stride; off < e;
205  off += stride)
206  offsets.push_back(off);
207  rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
208  op.getVector(), offsets);
209  return success();
210  }
211 };
212 
213 /// For a 1-D ExtractStridedSlice, breaks it down into a chain of Extract ops
214 /// to extract each element from the source, and then a chain of Insert ops
215 /// to insert to the target vector.
217  : public OpRewritePattern<ExtractStridedSliceOp> {
218 public:
220  MLIRContext *context,
221  std::function<bool(ExtractStridedSliceOp)> controlFn,
222  PatternBenefit benefit)
223  : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
224 
225  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
226  PatternRewriter &rewriter) const override {
227  if (controlFn && !controlFn(op))
228  return failure();
229 
230  // Only handle 1-D cases.
231  if (op.getOffsets().getValue().size() != 1)
232  return failure();
233 
234  int64_t offset =
235  cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
236  int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
237  int64_t stride =
238  cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
239 
240  Location loc = op.getLoc();
241  SmallVector<Value> elements;
242  elements.reserve(size);
243  for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
244  elements.push_back(rewriter.create<ExtractOp>(loc, op.getVector(), i));
245 
246  Value result = rewriter.create<arith::ConstantOp>(
247  loc, rewriter.getZeroAttr(op.getType()));
248  for (int64_t i = 0; i < size; ++i)
249  result = rewriter.create<InsertOp>(loc, elements[i], result, i);
250 
251  rewriter.replaceOp(op, result);
252  return success();
253  }
254 
255 private:
256  std::function<bool(ExtractStridedSliceOp)> controlFn;
257 };
258 
259 /// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
260 /// For such cases, we can rewrite it to ExtractOp + lower rank
261 /// ExtractStridedSliceOp + InsertOp for the n-D case.
263  : public OpRewritePattern<ExtractStridedSliceOp> {
264 public:
266 
267  void initialize() {
268  // This pattern creates recursive ExtractStridedSliceOp, but the recursion
269  // is bounded as the rank is strictly decreasing.
270  setHasBoundedRewriteRecursion();
271  }
272 
273  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
274  PatternRewriter &rewriter) const override {
275  auto dstType = op.getType();
276 
277  assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
278 
279  int64_t offset =
280  cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
281  int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
282  int64_t stride =
283  cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
284 
285  auto loc = op.getLoc();
286  auto elemType = dstType.getElementType();
287  assert(elemType.isSignlessIntOrIndexOrFloat());
288 
289  // Single offset can be more efficiently shuffled. It's handled in
290  // Convert1DExtractStridedSliceIntoShuffle.
291  if (op.getOffsets().getValue().size() == 1)
292  return failure();
293 
294  // Extract/insert on a lower ranked extract strided slice op.
295  Value zero = rewriter.create<arith::ConstantOp>(
296  loc, elemType, rewriter.getZeroAttr(elemType));
297  Value res = rewriter.create<SplatOp>(loc, dstType, zero);
298  for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
299  off += stride, ++idx) {
300  Value one = rewriter.create<ExtractOp>(loc, op.getVector(), off);
301  Value extracted = rewriter.create<ExtractStridedSliceOp>(
302  loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
303  getI64SubArray(op.getSizes(), /* dropFront=*/1),
304  getI64SubArray(op.getStrides(), /* dropFront=*/1));
305  res = rewriter.create<InsertOp>(loc, extracted, res, idx);
306  }
307  rewriter.replaceOp(op, res);
308  return success();
309  }
310 };
311 
313  RewritePatternSet &patterns, PatternBenefit benefit) {
315  DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
316 }
317 
319  RewritePatternSet &patterns,
320  std::function<bool(ExtractStridedSliceOp)> controlFn,
321  PatternBenefit benefit) {
323  patterns.getContext(), std::move(controlFn), benefit);
324 }
325 
326 /// Populate the given list with patterns that convert from Vector to LLVM.
328  RewritePatternSet &patterns, PatternBenefit benefit) {
330  benefit);
333  benefit);
334 }
For a 1-D ExtractStridedSlice, breaks it down into a chain of Extract ops to extract each element fro...
Convert1DExtractStridedSliceIntoExtractInsertChain(MLIRContext *context, std::function< bool(ExtractStridedSliceOp)> controlFn, PatternBenefit benefit)
LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for ExtractStridedSliceOp where source and destination vectors are 1-D.
LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for InsertStridedSliceOp where source and destination vectors have the same rank.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for InsertStridedSliceOp where source and destination vectors have different ranks.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void populateVectorExtractStridedSliceToExtractInsertChainPatterns(RewritePatternSet &patterns, std::function< bool(ExtractStridedSliceOp)> controlFn=nullptr, PatternBenefit benefit=1)
Populate patterns with a pattern to break down 1-D extract_strided_slice ops into a chain of Extract...
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
Include the generated interface declarations.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358