//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.shape_cast' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-shape-cast-lowering"

using namespace mlir;
using namespace mlir::vector;

/// Increments n-D `indices` by `step` starting from the innermost dimension.
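/// For example, given `vecType` with shape 2x3 and the default `step` of 1,
/// repeated calls advance `indices` as [0, 0] -> [0, 1] -> [0, 2] -> [1, 0],
/// i.e. a row-major counter that carries into the next outer dimension on
/// overflow.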
static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
                   int step = 1) {
  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
    assert(indices[dim] < vecType.getDimSize(dim) &&
           "Indices are out of bounds");
    indices[dim] += step;
    if (indices[dim] < vecType.getDimSize(dim))
      break;

    indices[dim] = 0;
    step = 1;
  }
}

namespace {
/// ShapeCastOp n-D -> 1-D downcast serves the purpose of flattening n-D to
/// 1-D vectors progressively. This iterates over the n-1 major dimensions of
/// the n-D vector and performs rewrites into:
/// vector.extract from n-D + vector.insert_strided_slice offset into 1-D
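///
/// For example (a sketch; SSA value names are illustrative):
/// ```
/// %0 = vector.shape_cast %arg0 : vector<2x3xf32> to vector<6xf32>
/// ```
/// would be rewritten along the lines of:
/// ```
/// %p  = ub.poison : vector<6xf32>
/// %e0 = vector.extract %arg0[0] : vector<3xf32> from vector<2x3xf32>
/// %r0 = vector.insert_strided_slice %e0, %p
///         {offsets = [0], strides = [1]} : vector<3xf32> into vector<6xf32>
/// %e1 = vector.extract %arg0[1] : vector<3xf32> from vector<2x3xf32>
/// %r1 = vector.insert_strided_slice %e1, %r0
///         {offsets = [3], strides = [1]} : vector<3xf32> into vector<6xf32>
/// ```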
class ShapeCastOpNDDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if (srcRank < 2 || resRank != 1)
      return failure();

    // Compute the number of 1-D vector elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t dim = 0; dim < srcRank - 1; ++dim)
      numElts *= sourceVectorType.getDimSize(dim);

    auto loc = op.getLoc();
    SmallVector<int64_t> srcIdx(srcRank - 1, 0);
    SmallVector<int64_t> resIdx(resRank, 0);
    int64_t extractSize = sourceVectorType.getShape().back();
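    // The poison value is only a placeholder: every lane of the result is
    // overwritten by the strided-slice insertions below.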
    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);

    // Compute the indices of each 1-D vector element of the source extraction
    // and destination slice insertion and generate such instructions.
    for (int64_t i = 0; i < numElts; ++i) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, /*step=*/1);
        incIdx(resIdx, resultVectorType, /*step=*/extractSize);
      }

      Value extract =
          rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extract, result,
          /*offsets=*/resIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

/// ShapeCastOp 1-D -> n-D upcast serves the purpose of unflattening n-D from
/// 1-D vectors progressively. This iterates over the n-1 major dimensions of
/// the n-D vector and performs rewrites into:
/// vector.extract_strided_slice from 1-D + vector.insert into n-D
/// Note that 1-D extract_strided_slice ops are lowered to efficient
/// vector.shuffle ops.
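///
/// For example (a sketch; SSA value names are illustrative):
/// ```
/// %0 = vector.shape_cast %arg0 : vector<6xf32> to vector<2x3xf32>
/// ```
/// would be rewritten along the lines of:
/// ```
/// %p  = ub.poison : vector<2x3xf32>
/// %e0 = vector.extract_strided_slice %arg0
///         {offsets = [0], sizes = [3], strides = [1]}
///         : vector<6xf32> to vector<3xf32>
/// %r0 = vector.insert %e0, %p [0] : vector<3xf32> into vector<2x3xf32>
/// %e1 = vector.extract_strided_slice %arg0
///         {offsets = [3], sizes = [3], strides = [1]}
///         : vector<6xf32> to vector<3xf32>
/// %r1 = vector.insert %e1, %r0 [1] : vector<3xf32> into vector<2x3xf32>
/// ```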
class ShapeCastOpNDUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if (srcRank != 1 || resRank < 2)
      return failure();

    // Compute the number of 1-D vector elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t dim = 0; dim < resRank - 1; ++dim)
      numElts *= resultVectorType.getDimSize(dim);

    // Compute the indices of each 1-D vector element of the source slice
    // extraction and destination insertion and generate such instructions.
    auto loc = op.getLoc();
    SmallVector<int64_t> srcIdx(srcRank, 0);
    SmallVector<int64_t> resIdx(resRank - 1, 0);
    int64_t extractSize = resultVectorType.getShape().back();
    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
    for (int64_t i = 0; i < numElts; ++i) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
        incIdx(resIdx, resultVectorType, /*step=*/1);
      }

      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.getSource(), /*offsets=*/srcIdx, /*sizes=*/extractSize,
          /*strides=*/1);
      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
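//
// For example (a sketch; SSA value names are illustrative, and %p is the
// initial ub.poison value), a rank-preserving reshape such as
//   %0 = vector.shape_cast %arg0 : vector<2x3xf32> to vector<3x2xf32>
// unrolls into six scalar extract/insert pairs, the first two being:
//   %e0 = vector.extract %arg0[0, 0] : f32 from vector<2x3xf32>
//   %r0 = vector.insert %e0, %p [0, 0] : f32 into vector<3x2xf32>
//   %e1 = vector.extract %arg0[0, 1] : f32 from vector<2x3xf32>
//   %r1 = vector.insert %e1, %r0 [0, 1] : f32 into vector<3x2xf32>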
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    // Special case for n-D / 1-D lowerings with better implementations.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
      return failure();

    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
    // extract/insert chains.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    //   x[0,0,0] = y[0,0]
    //   x[0,0,1] = y[0,1]
    //   x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
    SmallVector<int64_t> srcIdx(srcRank, 0);
    SmallVector<int64_t> resIdx(resRank, 0);
    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType);
        incIdx(resIdx, resultVectorType);
      }

      Value extract;
      if (srcRank == 0) {
        // 0-D vector special case.
        assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
        extract = rewriter.create<vector::ExtractElementOp>(
            loc, op.getSourceVectorType().getElementType(), op.getSource());
      } else {
        extract =
            rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
      }

      if (resRank == 0) {
        // 0-D vector special case.
        assert(resIdx.empty() && "Unexpected indices for 0-D vector");
        result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
      } else {
        result =
            rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// A shape_cast lowering for scalable vectors with a single trailing scalable
/// dimension. This is similar to the general shape_cast lowering but makes use
/// of vector.scalable.insert and vector.scalable.extract to move elements a
/// subvector at a time.
///
/// E.g.:
/// ```
/// // Flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Flatten scalable vector
/// %c = ub.poison : vector<[8]xi32>
/// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
/// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
/// ```
/// or:
/// ```
/// // Un-flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Un-flatten scalable vector
/// %c = ub.poison : vector<2x1x[4]xi32>
/// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
/// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
/// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// ```
class ScalableShapeCastOpRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    auto srcRank = sourceVectorType.getRank();
    auto resRank = resultVectorType.getRank();

    // This pattern can only lower shape_casts where both the source and
    // result types have a single trailing scalable dimension. This is because
    // there is no legal representation of other scalable types in LLVM (and
    // likely won't be soon). There are also (currently) no operations that
    // can index or extract from >= 2-D scalable vectors or scalable vectors
    // of fixed vectors.
    if (!isTrailingDimScalable(sourceVectorType) ||
        !isTrailingDimScalable(resultVectorType)) {
      return failure();
    }

    // The sizes of the trailing dimensions of the source and result vectors,
    // the size of the subvector to move, and the number of elements in the
    // vectors. These are "min" sizes as they are the sizes when vscale == 1.
    auto minSourceTrailingSize = sourceVectorType.getShape().back();
    auto minResultTrailingSize = resultVectorType.getShape().back();
    auto minExtractionSize =
        std::min(minSourceTrailingSize, minResultTrailingSize);
    int64_t minNumElts = 1;
    for (auto size : sourceVectorType.getShape())
      minNumElts *= size;

    // The subvector type to move from the source to the result. Note that
    // this is a scalable vector. This rewrite will generate code in terms of
    // the "min" size (the vscale == 1 case) that scales to any vscale.
    auto extractionVectorType = VectorType::get(
        {minExtractionSize}, sourceVectorType.getElementType(), {true});

    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
    SmallVector<int64_t> srcIdx(srcRank, 0);
    SmallVector<int64_t> resIdx(resRank, 0);

    // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
    // once D150000 lands.
    Value currentResultScalableVector;
    Value currentSourceScalableVector;
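    // Each loop iteration moves one subvector of minExtractionSize elements
    // (times vscale at runtime) from the source into the result.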
    for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
      // 1. Extract a scalable subvector from the source vector.
      if (!currentSourceScalableVector) {
        if (srcRank != 1) {
          currentSourceScalableVector = rewriter.create<vector::ExtractOp>(
              loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back());
        } else {
          currentSourceScalableVector = op.getSource();
        }
      }
      Value sourceSubVector = currentSourceScalableVector;
      if (minExtractionSize < minSourceTrailingSize) {
        sourceSubVector = rewriter.create<vector::ScalableExtractOp>(
            loc, extractionVectorType, sourceSubVector, srcIdx.back());
      }

      // 2. Insert the scalable subvector into the result vector.
      if (!currentResultScalableVector) {
        if (minExtractionSize == minResultTrailingSize) {
          currentResultScalableVector = sourceSubVector;
        } else if (resRank != 1) {
          currentResultScalableVector = rewriter.create<vector::ExtractOp>(
              loc, result, llvm::ArrayRef(resIdx).drop_back());
        } else {
          currentResultScalableVector = result;
        }
      }
      if (minExtractionSize < minResultTrailingSize) {
        currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>(
            loc, sourceSubVector, currentResultScalableVector, resIdx.back());
      }

      // 3. Update the source and result scalable vectors if needed.
      if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
          currentResultScalableVector != result) {
        // Finished row of result. Insert complete scalable vector into result
        // (n-D) vector.
        result = rewriter.create<vector::InsertOp>(
            loc, currentResultScalableVector, result,
            llvm::ArrayRef(resIdx).drop_back());
        currentResultScalableVector = {};
      }
      if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
        // Finished row of source.
        currentSourceScalableVector = {};
      }

      // 4. Increment the insert/extract indices, stepping by minExtractionSize
      // for the trailing dimensions.
      incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
      incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

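  /// Returns true if the only scalable dimension of `type` is the trailing
  /// one, e.g. vector<2x[4]xi32> qualifies but vector<[2]x4xi32> and
  /// vector<[2]x[4]xi32> do not.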
  static bool isTrailingDimScalable(VectorType type) {
    return type.getRank() >= 1 && type.getScalableDims().back() &&
           !llvm::is_contained(type.getScalableDims().drop_back(), true);
  }
};

} // namespace

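// Example usage (a sketch; assumes this runs inside a pass and uses the
// greedy rewrite driver from mlir/Transforms/GreedyPatternRewriteDriver.h):
//   RewritePatternSet patterns(op->getContext());
//   vector::populateVectorShapeCastLoweringPatterns(patterns);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     signalPassFailure();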
void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ShapeCastOpNDDownCastRewritePattern,
               ShapeCastOpNDUpCastRewritePattern, ShapeCastOpRewritePattern,
               ScalableShapeCastOpRewritePattern>(patterns.getContext(),
                                                  benefit);
}