MLIR  14.0.0git
VectorInsertExtractStridedSliceRewritePatterns.cpp
Go to the documentation of this file.
1 //===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
15 #include "mlir/IR/BuiltinTypes.h"
16 
17 using namespace mlir;
18 using namespace mlir::vector;
19 
20 // Helper that picks the proper sequence for inserting.
21 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
22  Value into, int64_t offset) {
23  auto vectorType = into.getType().cast<VectorType>();
24  if (vectorType.getRank() > 1)
25  return rewriter.create<InsertOp>(loc, from, into, offset);
26  return rewriter.create<vector::InsertElementOp>(
27  loc, vectorType, from, into,
28  rewriter.create<arith::ConstantIndexOp>(loc, offset));
29 }
30 
31 // Helper that picks the proper sequence for extracting.
32 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
33  int64_t offset) {
34  auto vectorType = vector.getType().cast<VectorType>();
35  if (vectorType.getRank() > 1)
36  return rewriter.create<ExtractOp>(loc, vector, offset);
37  return rewriter.create<vector::ExtractElementOp>(
38  loc, vectorType.getElementType(), vector,
39  rewriter.create<arith::ConstantIndexOp>(loc, offset));
40 }
41 
42 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
43 /// have different ranks.
44 ///
45 /// When ranks are different, InsertStridedSlice needs to extract a properly
46 /// ranked vector from the destination vector into which to insert. This pattern
47 /// only takes care of this extraction part and forwards the rest to
48 /// [VectorInsertStridedSliceOpSameRankRewritePattern].
49 ///
50 /// For a k-D source and n-D destination vector (k < n), we emit:
51 /// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
52 /// insert the k-D source.
53 /// 2. k-D -> (n-1)-D InsertStridedSlice op
54 /// 3. InsertOp that is the reverse of 1.
56  : public OpRewritePattern<InsertStridedSliceOp> {
57 public:
59 
60  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
61  PatternRewriter &rewriter) const override {
62  auto srcType = op.getSourceVectorType();
63  auto dstType = op.getDestVectorType();
64 
65  if (op.offsets().getValue().empty())
66  return failure();
67 
68  auto loc = op.getLoc();
69  int64_t rankDiff = dstType.getRank() - srcType.getRank();
70  assert(rankDiff >= 0);
71  if (rankDiff == 0)
72  return failure();
73 
74  int64_t rankRest = dstType.getRank() - rankDiff;
75  // Extract / insert the subvector of matching rank and InsertStridedSlice
76  // on it.
77  Value extracted =
78  rewriter.create<ExtractOp>(loc, op.dest(),
79  getI64SubArray(op.offsets(), /*dropFront=*/0,
80  /*dropBack=*/rankRest));
81 
82  // A different pattern will kick in for InsertStridedSlice with matching
83  // ranks.
84  auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
85  loc, op.source(), extracted,
86  getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
87  getI64SubArray(op.strides(), /*dropFront=*/0));
88 
89  rewriter.replaceOpWithNewOp<InsertOp>(
90  op, stridedSliceInnerOp.getResult(), op.dest(),
91  getI64SubArray(op.offsets(), /*dropFront=*/0,
92  /*dropBack=*/rankRest));
93  return success();
94  }
95 };
96 
97 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
98 /// have the same rank. For each outermost index in the slice:
99 /// begin end stride
100 /// [offset : offset+size*stride : stride]
101 /// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
102 /// 2. InsertStridedSlice (k-1)-D into (n-1)-D
103 /// 3. the destination subvector is inserted back in the proper place
104 /// 4. InsertOp that is the reverse of 1.
106  : public OpRewritePattern<InsertStridedSliceOp> {
107 public:
109 
110  void initialize() {
111  // This pattern creates recursive InsertStridedSliceOp, but the recursion is
112  // bounded as the rank is strictly decreasing.
113  setHasBoundedRewriteRecursion();
114  }
115 
116  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
117  PatternRewriter &rewriter) const override {
118  auto srcType = op.getSourceVectorType();
119  auto dstType = op.getDestVectorType();
120 
121  if (op.offsets().getValue().empty())
122  return failure();
123 
124  int64_t srcRank = srcType.getRank();
125  int64_t dstRank = dstType.getRank();
126  assert(dstRank >= srcRank);
127  if (dstRank != srcRank)
128  return failure();
129 
130  if (srcType == dstType) {
131  rewriter.replaceOp(op, op.source());
132  return success();
133  }
134 
135  int64_t offset =
136  op.offsets().getValue().front().cast<IntegerAttr>().getInt();
137  int64_t size = srcType.getShape().front();
138  int64_t stride =
139  op.strides().getValue().front().cast<IntegerAttr>().getInt();
140 
141  auto loc = op.getLoc();
142  Value res = op.dest();
143 
144  if (srcRank == 1) {
145  int nSrc = srcType.getShape().front();
146  int nDest = dstType.getShape().front();
147  // 1. Scale source to destType so we can shufflevector them together.
148  SmallVector<int64_t> offsets(nDest, 0);
149  for (int64_t i = 0; i < nSrc; ++i)
150  offsets[i] = i;
151  Value scaledSource =
152  rewriter.create<ShuffleOp>(loc, op.source(), op.source(), offsets);
153 
154  // 2. Create a mask where we take the value from scaledSource of dest
155  // depending on the offset.
156  offsets.clear();
157  for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
158  if (i < offset || i >= e || (i - offset) % stride != 0)
159  offsets.push_back(nDest + i);
160  else
161  offsets.push_back((i - offset) / stride);
162  }
163 
164  // 3. Replace with a ShuffleOp.
165  rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.dest(),
166  offsets);
167 
168  return success();
169  }
170 
171  // For each slice of the source vector along the most major dimension.
172  for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
173  off += stride, ++idx) {
174  // 1. extract the proper subvector (or element) from source
175  Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
176  if (extractedSource.getType().isa<VectorType>()) {
177  // 2. If we have a vector, extract the proper subvector from destination
178  // Otherwise we are at the element level and no need to recurse.
179  Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
180  // 3. Reduce the problem to lowering a new InsertStridedSlice op with
181  // smaller rank.
182  extractedSource = rewriter.create<InsertStridedSliceOp>(
183  loc, extractedSource, extractedDest,
184  getI64SubArray(op.offsets(), /* dropFront=*/1),
185  getI64SubArray(op.strides(), /* dropFront=*/1));
186  }
187  // 4. Insert the extractedSource into the res vector.
188  res = insertOne(rewriter, loc, extractedSource, res, off);
189  }
190 
191  rewriter.replaceOp(op, res);
192  return success();
193  }
194 };
195 
196 /// Progressive lowering of ExtractStridedSliceOp to either:
197 /// 1. single offset extract as a direct vector::ShuffleOp.
198 /// 2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp +
199 /// InsertOp/InsertElementOp for the n-D case.
201  : public OpRewritePattern<ExtractStridedSliceOp> {
202 public:
204 
205  void initialize() {
206  // This pattern creates recursive ExtractStridedSliceOp, but the recursion
207  // is bounded as the rank is strictly decreasing.
208  setHasBoundedRewriteRecursion();
209  }
210 
211  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
212  PatternRewriter &rewriter) const override {
213  auto dstType = op.getType();
214 
215  assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
216 
217  int64_t offset =
218  op.offsets().getValue().front().cast<IntegerAttr>().getInt();
219  int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
220  int64_t stride =
221  op.strides().getValue().front().cast<IntegerAttr>().getInt();
222 
223  auto loc = op.getLoc();
224  auto elemType = dstType.getElementType();
225  assert(elemType.isSignlessIntOrIndexOrFloat());
226 
227  // Single offset can be more efficiently shuffled.
228  if (op.offsets().getValue().size() == 1) {
229  SmallVector<int64_t, 4> offsets;
230  offsets.reserve(size);
231  for (int64_t off = offset, e = offset + size * stride; off < e;
232  off += stride)
233  offsets.push_back(off);
234  rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
235  op.vector(),
236  rewriter.getI64ArrayAttr(offsets));
237  return success();
238  }
239 
240  // Extract/insert on a lower ranked extract strided slice op.
241  Value zero = rewriter.create<arith::ConstantOp>(
242  loc, elemType, rewriter.getZeroAttr(elemType));
243  Value res = rewriter.create<SplatOp>(loc, dstType, zero);
244  for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
245  off += stride, ++idx) {
246  Value one = extractOne(rewriter, loc, op.vector(), off);
247  Value extracted = rewriter.create<ExtractStridedSliceOp>(
248  loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
249  getI64SubArray(op.sizes(), /* dropFront=*/1),
250  getI64SubArray(op.strides(), /* dropFront=*/1));
251  res = insertOne(rewriter, loc, extracted, res, idx);
252  }
253  rewriter.replaceOp(op, res);
254  return success();
255  }
256 };
257 
258 /// Populate the given list with patterns that convert from Vector to LLVM.
260  RewritePatternSet &patterns) {
264  patterns.getContext());
265 }
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
RewritePattern for InsertStridedSliceOp where source and destination vectors have different ranks...
Attribute getZeroAttr(Type type)
Definition: Builders.cpp:264
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:220
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns)
Populate patterns with the following patterns.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, Value into, int64_t offset)
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, int64_t offset)
Progressive lowering of ExtractStridedSliceOp to either:
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
Type getType() const
Return the type of this value.
Definition: Value.h:117
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override
RewritePattern for InsertStridedSliceOp where source and destination vectors have the same rank...
Specialization of arith.constant op that returns an integer of index type.
Definition: Arithmetic.h:78
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
bool isa() const
Definition: Types.h:234
MLIRContext * getContext() const
Definition: PatternMatch.h:906
U cast() const
Definition: Types.h:250