//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.shape_cast' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include <numeric>

#define DEBUG_TYPE "vector-shape-cast-lowering"

using namespace mlir;

/// Perform the inplace update
///     rhs <- lhs + rhs
///
/// where `rhs` is a number expressed in the mixed base `base`, with the most
/// significant dimensions on the left. For example, if `rhs` is {a,b,c} and
/// `base` is {5,3,2} then `rhs` has value a*3*2 + b*2 + c.
///
/// Some examples where `base` is {5,3,2}:
///   rhs = {0,0,0}, lhs = 1  --> rhs = {0,0,1}
///   rhs = {0,0,1}, lhs = 1  --> rhs = {0,1,0}
///   rhs = {0,0,0}, lhs = 25 --> rhs = {4,0,1}
///
/// Invalid:
///   rhs = {0,0,2}, lhs = 1 : rhs not in base {5,3,2}
///
/// Overflows are not handled correctly:
///   rhs = {4,2,1}, lhs = 2 --> rhs = {0,0,1} (the carry out of the most
///   significant dimension is dropped; the true sum, 31, is not representable)
static void inplaceAdd(int64_t lhs, ArrayRef<int64_t> base,
                       MutableArrayRef<int64_t> rhs) {

  // For dimensions in [numIndices - 1, ..., 3, 2, 1, 0]:
  for (int dim : llvm::reverse(llvm::seq<int>(0, rhs.size()))) {
    int64_t dimBase = base[dim];
    assert(rhs[dim] < dimBase && "rhs not in base");

    int64_t incremented = rhs[dim] + lhs;

    // If the incremented value exceeds the dimension base, we must spill to
    // the next most significant dimension and repeat (we might need to spill
    // to more significant dimensions multiple times).
    lhs = incremented / dimBase;
    rhs[dim] = incremented % dimBase;
    if (lhs == 0)
      break;
  }
}
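
// For illustration only (a sketch, not part of the lowering itself):
// repeatedly adding 1 with `inplaceAdd` walks an index vector through a shape
// in row-major order, which is how the patterns below advance their
// extract/insert indices. Assuming a shape (base) of {2,3}:
//
//   SmallVector<int64_t> idx(2, 0); // {0,0}
//   inplaceAdd(1, {2, 3}, idx);     // {0,1}
//   inplaceAdd(1, {2, 3}, idx);     // {0,2}
//   inplaceAdd(1, {2, 3}, idx);     // {1,0}  (spills into the next dimension)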

namespace {

/// shape_cast is converted to a sequence of extract, extract_strided_slice,
/// insert_strided_slice, and insert operations. The running example will be:
///
///   %0 = vector.shape_cast %arg0 :
///                  vector<2x2x3x4x7x11xi8> to vector<8x6x7x11xi8>
///
/// In this example the source and result shapes share a common suffix of 7x11.
/// This means we can always decompose the shape_cast into extract, insert, and
/// their strided equivalents, on vectors with shape suffix 7x11.
///
/// The greatest common divisor (gcd) of the dimensions immediately preceding
/// the common suffix is gcd(4,6) = 2. The algorithm implemented here operates
/// on vectors whose shapes are multiples of (what we define as) the
/// 'atomic shape', 2x7x11. The atomic shape is `gcd` x `common-suffix`:
///
///   vector<2x2x3x4x7x11xi8>  to  vector<8x6x7x11xi8>
///                ^ ^^^^                   ^ ^^^^
///                | |                      | |
///                | +----------------------|-+---> common suffix of 7x11
///                +------------------------+-----> gcd(4,6) is 2
///
///   atomic shape  =  gcd x common-suffix  =  2x7x11
///
/// The decomposition implemented in this pattern consists of a sequence of
/// repeated steps:
///
///   (1) Extract vectors from the suffix of the source.
///       In our example this is 2x2x3x4x7x11 -> 4x7x11.
///
///   (2) Do extract_strided_slice down to the atomic shape.
///       In our example this is 4x7x11 -> 2x7x11.
///
///   (3) Do insert_strided_slice to the suffix of the result.
///       In our example this is 2x7x11 -> 6x7x11.
///
///   (4) Insert these vectors into the result vector.
///       In our example this is 6x7x11 -> 8x6x7x11.
///
/// These steps occur with different periods. In this example
///   (1) occurs 12 times (once per 4x7x11 slice of the source: 2*2*3 = 12),
///   (2) and (3) occur 24 times (once per atomic 2x7x11 slice: 2*2*3*4/2 = 24),
///   (4) occurs 8 times (once per 6x7x11 slice of the result).
///
/// Two special cases are handled independently in this pattern:
///   (i)  a shape_cast that just does leading 1 insertion/removal, and
///   (ii) a shape_cast where the gcd is 1.
///
/// These two cases can have more compact IR generated by not using the generic
/// algorithm described above.
///
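/// As a smaller end-to-end illustration of the generic algorithm (hypothetical
/// shapes, not the running example), consider lowering
///
///   %0 = vector.shape_cast %arg0 : vector<2x6xi8> to vector<3x4xi8>
///
/// Here the trailing dimensions differ (6 vs 4), gcd(6,4) = 2, and the atomic
/// shape is just 2. The first two of the six atomic steps produce roughly:
///
///   %p  = ub.poison : vector<3x4xi8>
///   %q  = ub.poison : vector<4xi8>
///   %e0 = vector.extract %arg0[0] : vector<6xi8> from vector<2x6xi8>
///   %s0 = vector.extract_strided_slice %e0
///           {offsets = [0], sizes = [2], strides = [1]}
///           : vector<6xi8> to vector<2xi8>
///   %t0 = vector.insert_strided_slice %s0, %q
///           {offsets = [0], strides = [1]} : vector<2xi8> into vector<4xi8>
///   %s1 = vector.extract_strided_slice %e0
///           {offsets = [2], sizes = [2], strides = [1]}
///           : vector<6xi8> to vector<2xi8>
///   %t1 = vector.insert_strided_slice %s1, %t0
///           {offsets = [2], strides = [1]} : vector<2xi8> into vector<4xi8>
///   %r0 = vector.insert %t1, %p [0] : vector<4xi8> into vector<3x4xi8>
///
/// The remaining four atomic steps fill rows 1 and 2 of the result in the same
/// way.
///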
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {

  // Case (i) of description.
  // Assumes source and result shapes are identical up to some leading ones.
  static LogicalResult leadingOnesLowering(vector::ShapeCastOp shapeCast,
                                           PatternRewriter &rewriter) {

    const Location loc = shapeCast.getLoc();
    const VectorType sourceType = shapeCast.getSourceVectorType();
    const VectorType resultType = shapeCast.getResultVectorType();

    const int64_t sourceRank = sourceType.getRank();
    const int64_t resultRank = resultType.getRank();
    const int64_t delta = sourceRank - resultRank;
    const int64_t sourceLeading = delta > 0 ? delta : 0;
    const int64_t resultLeading = delta > 0 ? 0 : -delta;

    const Value source = shapeCast.getSource();
    const Value poison = ub::PoisonOp::create(rewriter, loc, resultType);
    const Value extracted = vector::ExtractOp::create(
        rewriter, loc, source, SmallVector<int64_t>(sourceLeading, 0));
    const Value result =
        vector::InsertOp::create(rewriter, loc, extracted, poison,
                                 SmallVector<int64_t>(resultLeading, 0));

    rewriter.replaceOp(shapeCast, result);
    return success();
  }
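
  // For illustration (hypothetical shapes): the leading-ones cast
  //   %0 = vector.shape_cast %arg0 : vector<1x1x4x5xi8> to vector<4x5xi8>
  // becomes a ub.poison of the result type, a single extract of the non-unit
  // part,
  //   %e = vector.extract %arg0[0, 0] : vector<4x5xi8> from vector<1x1x4x5xi8>
  // and an insert of %e into the poison value (with no leading indices here).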

  // Case (ii) of description.
  // Assumes a shape_cast where the suffix shape of the source starting at
  // `sourceDim` and the suffix shape of the result starting at `resultDim` are
  // identical.
  static LogicalResult noStridedSliceLowering(vector::ShapeCastOp shapeCast,
                                              int64_t sourceDim,
                                              int64_t resultDim,
                                              PatternRewriter &rewriter) {

    const Location loc = shapeCast.getLoc();

    const Value source = shapeCast.getSource();
    const ArrayRef<int64_t> sourceShape =
        shapeCast.getSourceVectorType().getShape();

    const VectorType resultType = shapeCast.getResultVectorType();
    const ArrayRef<int64_t> resultShape = resultType.getShape();

    const int64_t nSlices =
        std::accumulate(sourceShape.begin(), sourceShape.begin() + sourceDim, 1,
                        std::multiplies<int64_t>());

    SmallVector<int64_t> extractIndex(sourceDim, 0);
    SmallVector<int64_t> insertIndex(resultDim, 0);
    Value result = ub::PoisonOp::create(rewriter, loc, resultType);

    for (int i = 0; i < nSlices; ++i) {
      Value extracted =
          vector::ExtractOp::create(rewriter, loc, source, extractIndex);

      result = vector::InsertOp::create(rewriter, loc, extracted, result,
                                        insertIndex);

      inplaceAdd(1, sourceShape.take_front(sourceDim), extractIndex);
      inplaceAdd(1, resultShape.take_front(resultDim), insertIndex);
    }
    rewriter.replaceOp(shapeCast, result);
    return success();
  }
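
  // For illustration (hypothetical shapes): with a gcd of 1, e.g.
  //   %0 = vector.shape_cast %arg0 : vector<2x3xi8> to vector<3x2xi8>
  // the elements are moved one at a time, roughly:
  //   %p  = ub.poison : vector<3x2xi8>
  //   %e0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8>
  //   %r0 = vector.insert %e0, %p [0, 0] : i8 into vector<3x2xi8>
  //   %e1 = vector.extract %arg0[0, 1] : i8 from vector<2x3xi8>
  //   %r1 = vector.insert %e1, %r0 [0, 1] : i8 into vector<3x2xi8>
  // followed by four more extract/insert pairs that walk both index sets in
  // row-major order.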

public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType sourceType = op.getSourceVectorType();
    VectorType resultType = op.getResultVectorType();

    if (sourceType.isScalable() || resultType.isScalable())
      return rewriter.notifyMatchFailure(
          op,
          "shape_cast where vectors are scalable not handled by this pattern");

    const ArrayRef<int64_t> sourceShape = sourceType.getShape();
    const ArrayRef<int64_t> resultShape = resultType.getShape();
    const int64_t sourceRank = sourceType.getRank();
    const int64_t resultRank = resultType.getRank();
    const int64_t numElms = sourceType.getNumElements();
    const Value source = op.getSource();

    // Find the first dimension (starting from the end) where the source and
    // result dimension sizes differ. Using the running example:
    //
    //   dimensions: [0 1 2 3 4 5 ]    [0 1 2 3 ]
    //   shapes:     (2,2,3,4,7,11) -> (8,6,7,11)
    //                      ^             ^
    //                      |             |
    //                      |             resultSuffixStartDim is 1
    //                      |
    //                      sourceSuffixStartDim is 3
    int64_t sourceSuffixStartDim = sourceRank - 1;
    int64_t resultSuffixStartDim = resultRank - 1;
    while (sourceSuffixStartDim >= 0 && resultSuffixStartDim >= 0 &&
           (sourceType.getDimSize(sourceSuffixStartDim) ==
            resultType.getDimSize(resultSuffixStartDim))) {
      --sourceSuffixStartDim;
      --resultSuffixStartDim;
    }

    // This is the case (i) where there are just some leading ones to contend
    // with in the source or result. It can be handled with a single
    // extract/insert pair.
    if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0)
      return leadingOnesLowering(op, rewriter);

    const int64_t sourceSuffixStartDimSize =
        sourceType.getDimSize(sourceSuffixStartDim);
    const int64_t resultSuffixStartDimSize =
        resultType.getDimSize(resultSuffixStartDim);
    const int64_t greatestCommonDivisor =
        std::gcd(sourceSuffixStartDimSize, resultSuffixStartDimSize);
    const int64_t stridedSliceRank = sourceRank - sourceSuffixStartDim;
    const size_t extractPeriod =
        sourceSuffixStartDimSize / greatestCommonDivisor;
    const size_t insertPeriod =
        resultSuffixStartDimSize / greatestCommonDivisor;

    SmallVector<int64_t> atomicShape(sourceShape.begin() + sourceSuffixStartDim,
                                     sourceShape.end());
    atomicShape[0] = greatestCommonDivisor;

    const int64_t numAtomicElms = std::accumulate(
        atomicShape.begin(), atomicShape.end(), 1, std::multiplies<int64_t>());
    const size_t nAtomicSlices = numElms / numAtomicElms;

    // This is the case (ii) where the strided dimension size is 1. More
    // compact IR is generated in this case if we just extract and insert the
    // elements directly. In other words, we don't use extract_strided_slice
    // and insert_strided_slice.
    if (greatestCommonDivisor == 1)
      return noStridedSliceLowering(op, sourceSuffixStartDim + 1,
                                    resultSuffixStartDim + 1, rewriter);

    // The insert_strided_slice result's type.
    const ArrayRef<int64_t> insertStridedShape =
        resultShape.drop_front(resultSuffixStartDim);
    const VectorType insertStridedType =
        VectorType::get(insertStridedShape, resultType.getElementType());

    SmallVector<int64_t> extractIndex(sourceSuffixStartDim, 0);
    SmallVector<int64_t> insertIndex(resultSuffixStartDim, 0);
    SmallVector<int64_t> extractOffsets(stridedSliceRank, 0);
    SmallVector<int64_t> insertOffsets(stridedSliceRank, 0);
    const SmallVector<int64_t> sizes(stridedSliceRank, 1);

    Value extracted = {};
    Value extractedStrided = {};
    Value insertedSlice = {};
    Value result = ub::PoisonOp::create(rewriter, loc, resultType);
    const Value partResult =
        ub::PoisonOp::create(rewriter, loc, insertStridedType);

    for (size_t i = 0; i < nAtomicSlices; ++i) {

      const size_t extractStridedPhase = i % extractPeriod;
      const size_t insertStridedPhase = i % insertPeriod;

      // vector.extract
      if (extractStridedPhase == 0) {
        extracted =
            vector::ExtractOp::create(rewriter, loc, source, extractIndex);
        inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim),
                   extractIndex);
      }

      // vector.extract_strided_slice
      extractOffsets[0] = extractStridedPhase * greatestCommonDivisor;
      extractedStrided = vector::ExtractStridedSliceOp::create(
          rewriter, loc, extracted, extractOffsets, atomicShape, sizes);

      // vector.insert_strided_slice
      if (insertStridedPhase == 0) {
        insertedSlice = partResult;
      }
      insertOffsets[0] = insertStridedPhase * greatestCommonDivisor;
      insertedSlice = vector::InsertStridedSliceOp::create(
          rewriter, loc, extractedStrided, insertedSlice, insertOffsets, sizes);

      // vector.insert
      if (insertStridedPhase + 1 == insertPeriod) {
        result = vector::InsertOp::create(rewriter, loc, insertedSlice, result,
                                          insertIndex);
        inplaceAdd(1, resultType.getShape().take_front(resultSuffixStartDim),
                   insertIndex);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// A shape_cast lowering for scalable vectors with a single trailing scalable
/// dimension. This is similar to the general shape_cast lowering but makes use
/// of vector.scalable.insert and vector.scalable.extract to move elements a
/// subvector at a time.
///
/// E.g.:
/// ```
/// // Flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Flatten scalable vector
/// %c = ub.poison : vector<[8]xi32>
/// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
/// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
/// ```
/// or:
/// ```
/// // Un-flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Un-flatten scalable vector
/// %c = ub.poison : vector<2x1x[4]xi32>
/// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
/// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
/// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// ```
class ScalableShapeCastOpRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    auto srcRank = sourceVectorType.getRank();
    auto resRank = resultVectorType.getRank();

    // This can only lower shape_casts where both the source and result types
    // have a single trailing scalable dimension. This is because there is no
    // legal representation of other scalable types in LLVM (and likely won't
    // be any time soon). There are also (currently) no operations that can
    // index or extract from >= 2-D scalable vectors or scalable vectors of
    // fixed vectors.
    if (!isTrailingDimScalable(sourceVectorType) ||
        !isTrailingDimScalable(resultVectorType)) {
      return rewriter.notifyMatchFailure(
          op, "trailing dims are not scalable, not handled by this pattern");
    }

    // The sizes of the trailing dimensions of the source and result vectors,
    // the size of the subvector to move, and the number of elements in the
    // vectors. These are "min" sizes as they are the sizes when vscale == 1.
    auto minSourceTrailingSize = sourceVectorType.getShape().back();
    auto minResultTrailingSize = resultVectorType.getShape().back();
    auto minExtractionSize =
        std::min(minSourceTrailingSize, minResultTrailingSize);
    int64_t minNumElts = 1;
    for (auto size : sourceVectorType.getShape())
      minNumElts *= size;

    // The subvector type to move from the source to the result. Note that this
    // is a scalable vector. This rewrite will generate code in terms of the
    // "min" size (the vscale == 1 case), which scales to any vscale.
    auto extractionVectorType = VectorType::get(
        {minExtractionSize}, sourceVectorType.getElementType(), {true});

    Value result = ub::PoisonOp::create(rewriter, loc, resultVectorType);
    SmallVector<int64_t> srcIdx(srcRank, 0);
    SmallVector<int64_t> resIdx(resRank, 0);

    // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
    // once D150000 lands.
    Value currentResultScalableVector;
    Value currentSourceScalableVector;
    for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
      // 1. Extract a scalable subvector from the source vector.
      if (!currentSourceScalableVector) {
        if (srcRank != 1) {
          currentSourceScalableVector =
              vector::ExtractOp::create(rewriter, loc, op.getSource(),
                                        llvm::ArrayRef(srcIdx).drop_back());
        } else {
          currentSourceScalableVector = op.getSource();
        }
      }
      Value sourceSubVector = currentSourceScalableVector;
      if (minExtractionSize < minSourceTrailingSize) {
        sourceSubVector = vector::ScalableExtractOp::create(
            rewriter, loc, extractionVectorType, sourceSubVector,
            srcIdx.back());
      }

      // 2. Insert the scalable subvector into the result vector.
      if (!currentResultScalableVector) {
        if (minExtractionSize == minResultTrailingSize) {
          currentResultScalableVector = sourceSubVector;
        } else if (resRank != 1) {
          currentResultScalableVector = vector::ExtractOp::create(
              rewriter, loc, result, llvm::ArrayRef(resIdx).drop_back());
        } else {
          currentResultScalableVector = result;
        }
      }
      if (minExtractionSize < minResultTrailingSize) {
        currentResultScalableVector = vector::ScalableInsertOp::create(
            rewriter, loc, sourceSubVector, currentResultScalableVector,
            resIdx.back());
      }

      // 3. Update the source and result scalable vectors if needed.
      if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
          currentResultScalableVector != result) {
        // Finished row of result. Insert complete scalable vector into result
        // (n-D) vector.
        result = vector::InsertOp::create(rewriter, loc,
                                          currentResultScalableVector, result,
                                          llvm::ArrayRef(resIdx).drop_back());
        currentResultScalableVector = {};
      }
      if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
        // Finished row of source.
        currentSourceScalableVector = {};
      }

      // 4. Increment the insert/extract indices, stepping by minExtractionSize
      // for the trailing dimensions.
      inplaceAdd(minExtractionSize, sourceVectorType.getShape(), srcIdx);
      inplaceAdd(minExtractionSize, resultVectorType.getShape(), resIdx);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

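  // For example (illustrative): vector<2x[4]xi32> has a single trailing
  // scalable dimension, while vector<[2]x4xi32> and vector<4xi32> do not, so
  // only the first would be accepted by isTrailingDimScalable below.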
  static bool isTrailingDimScalable(VectorType type) {
    return type.getRank() >= 1 && type.getScalableDims().back() &&
           !llvm::is_contained(type.getScalableDims().drop_back(), true);
  }
};

} // namespace

void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ShapeCastOpRewritePattern, ScalableShapeCastOpRewritePattern>(
      patterns.getContext(), benefit);
}
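
// Typical usage (an illustrative sketch, not part of this lowering): the
// patterns above are usually added to a RewritePatternSet inside a pass and
// applied with the greedy driver. `MyVectorLoweringPass` is a hypothetical
// pass; `applyPatternsGreedily` is declared in
// mlir/Transforms/GreedyPatternRewriteDriver.h.
//
//   void MyVectorLoweringPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     vector::populateVectorShapeCastLoweringPatterns(patterns);
//     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//       signalPassFailure();
//   }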