//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.shape_cast' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include <numeric>

#define DEBUG_TYPE "vector-shape-cast-lowering"

using namespace mlir;

/// Perform the inplace update
///   rhs <- lhs + rhs
///
/// where `rhs` is a number expressed in mixed base `base` with most
/// significant dimensions on the left. For example if `rhs` is {a,b,c} and
/// `base` is {5,3,2} then `rhs` has value a*3*2 + b*2 + c.
///
/// Some examples where `base` is {5,3,2}:
///   rhs = {0,0,0}, lhs = 1  --> rhs = {0,0,1}
///   rhs = {0,0,1}, lhs = 1  --> rhs = {0,1,0}
///   rhs = {0,0,0}, lhs = 25 --> rhs = {4,0,1}
///
/// Invalid:
///   rhs = {0,0,2}, lhs = 1 : rhs not in base {5,3,2}
///
/// Overflow is not handled: the carry out of the most significant dimension
/// is silently discarded. For example:
///   rhs = {4,2,1}, lhs = 2 --> rhs = {0,0,1} (the true sum, 31, would need a
///   fourth digit: {1,0,0,1})
static void inplaceAdd(int64_t lhs, ArrayRef<int64_t> base,
                       MutableArrayRef<int64_t> rhs) {

  // For dimensions in [rhs.size() - 1, ..., 2, 1, 0]:
  for (int dim : llvm::reverse(llvm::seq<int>(0, rhs.size()))) {
    int64_t dimBase = base[dim];
    assert(rhs[dim] < dimBase && "rhs not in base");

    int64_t incremented = rhs[dim] + lhs;

    // If the incremented value exceeds the dimension base, we must spill to
    // the next most significant dimension and repeat (we might need to spill
    // to more significant dimensions multiple times).
    lhs = incremented / dimBase;
    rhs[dim] = incremented % dimBase;
    if (lhs == 0)
      break;
  }
}
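
// As a sanity check, a worked trace of the {0,0,0} + 25 example above, with
// base = {5,3,2}:
//   dim 2: 0 + 25 = 25 --> rhs[2] = 25 % 2 = 1, carry = 25 / 2 = 12
//   dim 1: 0 + 12 = 12 --> rhs[1] = 12 % 3 = 0, carry = 12 / 3 = 4
//   dim 0: 0 + 4  = 4  --> rhs[0] = 4 % 5  = 4, carry = 0 (loop exits)
// leaving rhs = {4,0,1}, i.e. 4*(3*2) + 0*2 + 1 = 25.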

namespace {

/// shape_cast is converted to a sequence of extract, extract_strided_slice,
/// insert_strided_slice, and insert operations. The running example will be:
///
///  %0 = vector.shape_cast %arg0 :
///                vector<2x2x3x4x7x11xi8> to vector<8x6x7x11xi8>
///
/// In this example the source and result shapes share a common suffix of 7x11.
/// This means we can always decompose the shape_cast into extract, insert, and
/// their strided equivalents, on vectors with shape suffix 7x11.
///
/// The greatest common divisor (gcd) of the first dimension preceding the
/// common suffix is gcd(4,6) = 2. The algorithm implemented here will operate
/// on vectors with shapes that are `multiples` of (what we define as) the
/// 'atomic shape', 2x7x11. The atomic shape is `gcd` x `common-suffix`.
///
///    vector<2x2x3x4x7x11xi8>  to  vector<8x6x7x11xi8>
///                 | ^^^^                   | ^^^^
///                 |  |                     |  |
///                 |  +---------------------+--+--> common suffix of 7x11
///                 |                        |
///                 +------------------------+-----> gcd(4,6) is 2
///
///    atomic shape: 2x7x11
///
/// The decomposition implemented in this pattern consists of a sequence of
/// repeated steps:
///
///  (1) Extract vectors from the suffix of the source.
///      In our example this is 2x2x3x4x7x11 -> 4x7x11.
///
///  (2) Do extract_strided_slice down to the atomic shape.
///      In our example this is 4x7x11 -> 2x7x11.
///
///  (3) Do insert_strided_slice to the suffix of the result.
///      In our example this is 2x7x11 -> 6x7x11.
///
///  (4) Insert these vectors into the result vector.
///      In our example this is 6x7x11 -> 8x6x7x11.
///
/// These steps occur with different periods. In this example
///  (1) occurs 12 times,
///  (2) and (3) occur 24 times, and
///  (4) occurs 8 times.
///
/// Two special cases are handled independently in this pattern:
///  (i)  a shape_cast that just does leading 1 insertion/removal, and
///  (ii) a shape_cast where the gcd is 1.
///
/// These two cases can have more compact IR generated by not using the generic
/// algorithm described above.
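///
/// For illustration, the ops generated for the first atomic slice of the
/// running example look roughly as follows (a sketch, not verbatim output):
///
///   %r    = ub.poison : vector<8x6x7x11xi8>
///   %part = ub.poison : vector<6x7x11xi8>
///   // (1) extract a 4x7x11 slice of the source
///   %e  = vector.extract %arg0[0, 0, 0]
///           : vector<4x7x11xi8> from vector<2x2x3x4x7x11xi8>
///   // (2) extract_strided_slice down to the atomic shape 2x7x11
///   %es = vector.extract_strided_slice %e
///           {offsets = [0, 0, 0], sizes = [2, 7, 11], strides = [1, 1, 1]}
///           : vector<4x7x11xi8> to vector<2x7x11xi8>
///   // (3) insert_strided_slice into a 6x7x11 piece of the result
///   %is = vector.insert_strided_slice %es, %part
///           {offsets = [0, 0, 0], strides = [1, 1, 1]}
///           : vector<2x7x11xi8> into vector<6x7x11xi8>
///   // (4) once a 6x7x11 piece is complete (after 3 strided inserts),
///   //     insert it into the result
///   %r0 = vector.insert %is, %r [0]
///           : vector<6x7x11xi8> into vector<8x6x7x11xi8>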
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {

  // Case (i) of description.
  // Assumes source and result shapes are identical up to some leading ones.
  static LogicalResult leadingOnesLowering(vector::ShapeCastOp shapeCast,
                                           PatternRewriter &rewriter) {

    const Location loc = shapeCast.getLoc();
    const VectorType sourceType = shapeCast.getSourceVectorType();
    const VectorType resultType = shapeCast.getResultVectorType();

    const int64_t sourceRank = sourceType.getRank();
    const int64_t resultRank = resultType.getRank();
    const int64_t delta = sourceRank - resultRank;
    const int64_t sourceLeading = delta > 0 ? delta : 0;
    const int64_t resultLeading = delta > 0 ? 0 : -delta;

    const Value source = shapeCast.getSource();
    const Value poison = ub::PoisonOp::create(rewriter, loc, resultType);
    const Value extracted = vector::ExtractOp::create(
        rewriter, loc, source, SmallVector<int64_t>(sourceLeading, 0));
    const Value result =
        vector::InsertOp::create(rewriter, loc, extracted, poison,
                                 SmallVector<int64_t>(resultLeading, 0));

    rewriter.replaceOp(shapeCast, result);
    return success();
  }
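
  // For illustration, a cast that only adds leading unit dimensions, e.g.
  //   %0 = vector.shape_cast %arg0 : vector<4x5xi8> to vector<1x1x4x5xi8>
  // lowers to roughly (a sketch):
  //   %p = ub.poison : vector<1x1x4x5xi8>
  //   %0 = vector.insert %arg0, %p [0, 0]
  //          : vector<4x5xi8> into vector<1x1x4x5xi8>
  // Here delta = -2, so nothing needs to be extracted from the source.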

  // Case (ii) of description.
  // Assumes a shape_cast where the suffix shape of the source starting at
  // `sourceDim` and the suffix shape of the result starting at `resultDim` are
  // identical.
  static LogicalResult noStridedSliceLowering(vector::ShapeCastOp shapeCast,
                                              int64_t sourceDim,
                                              int64_t resultDim,
                                              PatternRewriter &rewriter) {

    const Location loc = shapeCast.getLoc();

    const Value source = shapeCast.getSource();
    const ArrayRef<int64_t> sourceShape =
        shapeCast.getSourceVectorType().getShape();

    const VectorType resultType = shapeCast.getResultVectorType();
    const ArrayRef<int64_t> resultShape = resultType.getShape();

    const int64_t nSlices = llvm::product_of(sourceShape.take_front(sourceDim));
    SmallVector<int64_t> extractIndex(sourceDim, 0);
    SmallVector<int64_t> insertIndex(resultDim, 0);
    Value result = ub::PoisonOp::create(rewriter, loc, resultType);

    for (int i = 0; i < nSlices; ++i) {
      Value extracted =
          vector::ExtractOp::create(rewriter, loc, source, extractIndex);

      result = vector::InsertOp::create(rewriter, loc, extracted, result,
                                        insertIndex);

      inplaceAdd(1, sourceShape.take_front(sourceDim), extractIndex);
      inplaceAdd(1, resultShape.take_front(resultDim), insertIndex);
    }
    rewriter.replaceOp(shapeCast, result);
    return success();
  }
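
  // For illustration, a gcd-1 case such as
  //   %0 = vector.shape_cast %arg0 : vector<2x3xi8> to vector<3x2xi8>
  // lowers to roughly (a sketch):
  //   %p  = ub.poison : vector<3x2xi8>
  //   %e0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8>
  //   %r0 = vector.insert %e0, %p [0, 0] : i8 into vector<3x2xi8>
  //   %e1 = vector.extract %arg0[0, 1] : i8 from vector<2x3xi8>
  //   %r1 = vector.insert %e1, %r0 [0, 1] : i8 into vector<3x2xi8>
  //   ...
  // with six extract/insert pairs in total.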

public:
  using Base::Base;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType sourceType = op.getSourceVectorType();
    VectorType resultType = op.getResultVectorType();

    if (sourceType.isScalable() || resultType.isScalable())
      return rewriter.notifyMatchFailure(
          op,
          "shape_cast where vectors are scalable not handled by this pattern");

    const ArrayRef<int64_t> sourceShape = sourceType.getShape();
    const ArrayRef<int64_t> resultShape = resultType.getShape();
    const int64_t sourceRank = sourceType.getRank();
    const int64_t resultRank = resultType.getRank();
    const int64_t numElms = sourceType.getNumElements();
    const Value source = op.getSource();

    // Find the first dimension (walking from the back) in the source and the
    // result, respectively, at which the dimension sizes differ. Using the
    // running example:
    //
    //   dimensions:  [0  1  2  3  4  5 ]      [0  1  2  3 ]
    //   shapes:      (2, 2, 3, 4, 7, 11)  ->  (8, 6, 7, 11)
    //                          ^                  ^
    //                          |                  |
    //      sourceSuffixStartDim is 3              |
    //                                             |
    //                        resultSuffixStartDim is 1
    int64_t sourceSuffixStartDim = sourceRank - 1;
    int64_t resultSuffixStartDim = resultRank - 1;
    while (sourceSuffixStartDim >= 0 && resultSuffixStartDim >= 0 &&
           (sourceType.getDimSize(sourceSuffixStartDim) ==
            resultType.getDimSize(resultSuffixStartDim))) {
      --sourceSuffixStartDim;
      --resultSuffixStartDim;
    }

    // This is case (i), where there are just some leading ones to contend
    // with in the source or result. It can be handled with a single
    // extract/insert pair.
    if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0)
      return leadingOnesLowering(op, rewriter);

    const int64_t sourceSuffixStartDimSize =
        sourceType.getDimSize(sourceSuffixStartDim);
    const int64_t resultSuffixStartDimSize =
        resultType.getDimSize(resultSuffixStartDim);
    const int64_t greatestCommonDivisor =
        std::gcd(sourceSuffixStartDimSize, resultSuffixStartDimSize);
    const int64_t stridedSliceRank = sourceRank - sourceSuffixStartDim;
    const size_t extractPeriod =
        sourceSuffixStartDimSize / greatestCommonDivisor;
    const size_t insertPeriod =
        resultSuffixStartDimSize / greatestCommonDivisor;

    SmallVector<int64_t> atomicShape(sourceShape.begin() + sourceSuffixStartDim,
                                     sourceShape.end());
    atomicShape[0] = greatestCommonDivisor;

    const int64_t numAtomicElms = std::accumulate(
        atomicShape.begin(), atomicShape.end(), 1, std::multiplies<int64_t>());
    const size_t nAtomicSlices = numElms / numAtomicElms;
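
    // Running example: extractPeriod = 4 / 2 = 2, insertPeriod = 6 / 2 = 3,
    // numAtomicElms = 2 * 7 * 11 = 154, and
    // nAtomicSlices = (2 * 2 * 3 * 4 * 7 * 11) / 154 = 24.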

    // This is case (ii), where the gcd is 1. More compact IR is generated in
    // this case if we just extract and insert the elements directly. In other
    // words, we don't use extract_strided_slice and insert_strided_slice.
    if (greatestCommonDivisor == 1)
      return noStridedSliceLowering(op, sourceSuffixStartDim + 1,
                                    resultSuffixStartDim + 1, rewriter);

    // The insert_strided_slice result's type.
    const ArrayRef<int64_t> insertStridedShape =
        resultShape.drop_front(resultSuffixStartDim);
    const VectorType insertStridedType =
        VectorType::get(insertStridedShape, resultType.getElementType());

    SmallVector<int64_t> extractIndex(sourceSuffixStartDim, 0);
    SmallVector<int64_t> insertIndex(resultSuffixStartDim, 0);
    SmallVector<int64_t> extractOffsets(stridedSliceRank, 0);
    SmallVector<int64_t> insertOffsets(stridedSliceRank, 0);
    const SmallVector<int64_t> sizes(stridedSliceRank, 1);

    Value extracted = {};
    Value extractedStrided = {};
    Value insertedSlice = {};
    Value result = ub::PoisonOp::create(rewriter, loc, resultType);
    const Value partResult =
        ub::PoisonOp::create(rewriter, loc, insertStridedType);

    for (size_t i = 0; i < nAtomicSlices; ++i) {

      const size_t extractStridedPhase = i % extractPeriod;
      const size_t insertStridedPhase = i % insertPeriod;

      // vector.extract
      if (extractStridedPhase == 0) {
        extracted =
            vector::ExtractOp::create(rewriter, loc, source, extractIndex);
        inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim),
                   extractIndex);
      }

      // vector.extract_strided_slice
      extractOffsets[0] = extractStridedPhase * greatestCommonDivisor;
      extractedStrided = vector::ExtractStridedSliceOp::create(
          rewriter, loc, extracted, extractOffsets, atomicShape, sizes);

      // vector.insert_strided_slice
      if (insertStridedPhase == 0) {
        insertedSlice = partResult;
      }
      insertOffsets[0] = insertStridedPhase * greatestCommonDivisor;
      insertedSlice = vector::InsertStridedSliceOp::create(
          rewriter, loc, extractedStrided, insertedSlice, insertOffsets, sizes);

      // vector.insert
      if (insertStridedPhase + 1 == insertPeriod) {
        result = vector::InsertOp::create(rewriter, loc, insertedSlice, result,
                                          insertIndex);
        inplaceAdd(1, resultType.getShape().take_front(resultSuffixStartDim),
                   insertIndex);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// A shape_cast lowering for scalable vectors with a single trailing scalable
/// dimension. This is similar to the general shape_cast lowering but makes use
/// of vector.scalable.insert and vector.scalable.extract to move elements a
/// subvector at a time.
///
/// E.g.:
/// ```
/// // Flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Flatten scalable vector
/// %c = arith.constant dense<0> : vector<[8]xi32>
/// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
/// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
/// ```
/// or:
/// ```
/// // Un-flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Un-flatten scalable vector
/// %c = arith.constant dense<0> : vector<2x1x[4]xi32>
/// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
/// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
/// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// ```
class ScalableShapeCastOpRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    auto srcRank = sourceVectorType.getRank();
    auto resRank = resultVectorType.getRank();

    // This pattern can only lower shape_casts where both the source and result
    // types have a single trailing scalable dimension. This is because there
    // is no legal representation of other scalable types in LLVM (and likely
    // won't be soon). There are also (currently) no operations that can index
    // or extract from >= 2-D scalable vectors or scalable vectors of fixed
    // vectors.
    if (!isTrailingDimScalable(sourceVectorType) ||
        !isTrailingDimScalable(resultVectorType)) {
      return rewriter.notifyMatchFailure(
          op, "trailing dims are not scalable, not handled by this pattern");
    }

    // The sizes of the trailing dimensions of the source and result vectors,
    // the size of the subvector to move, and the number of elements in the
    // vectors. These are "min" sizes, i.e. the sizes when vscale == 1.
    auto minSourceTrailingSize = sourceVectorType.getShape().back();
    auto minResultTrailingSize = resultVectorType.getShape().back();
    auto minExtractionSize =
        std::min(minSourceTrailingSize, minResultTrailingSize);
    int64_t minNumElts = 1;
    for (auto size : sourceVectorType.getShape())
      minNumElts *= size;

    // The subvector type to move from the source to the result. Note that this
    // is a scalable vector. This rewrite will generate code in terms of the
    // "min" size (the vscale == 1 case), which scales to any vscale.
    auto extractionVectorType = VectorType::get(
        {minExtractionSize}, sourceVectorType.getElementType(), {true});

    Value result = ub::PoisonOp::create(rewriter, loc, resultVectorType);
    SmallVector<int64_t> srcIdx(srcRank, 0);
    SmallVector<int64_t> resIdx(resRank, 0);

    // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
    // once D150000 lands.
    Value currentResultScalableVector;
    Value currentSourceScalableVector;
    for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
      // 1. Extract a scalable subvector from the source vector.
      if (!currentSourceScalableVector) {
        if (srcRank != 1) {
          currentSourceScalableVector =
              vector::ExtractOp::create(rewriter, loc, op.getSource(),
                                        llvm::ArrayRef(srcIdx).drop_back());
        } else {
          currentSourceScalableVector = op.getSource();
        }
      }
      Value sourceSubVector = currentSourceScalableVector;
      if (minExtractionSize < minSourceTrailingSize) {
        sourceSubVector = vector::ScalableExtractOp::create(
            rewriter, loc, extractionVectorType, sourceSubVector,
            srcIdx.back());
      }

      // 2. Insert the scalable subvector into the result vector.
      if (!currentResultScalableVector) {
        if (minExtractionSize == minResultTrailingSize) {
          currentResultScalableVector = sourceSubVector;
        } else if (resRank != 1) {
          currentResultScalableVector = vector::ExtractOp::create(
              rewriter, loc, result, llvm::ArrayRef(resIdx).drop_back());
        } else {
          currentResultScalableVector = result;
        }
      }
      if (minExtractionSize < minResultTrailingSize) {
        currentResultScalableVector = vector::ScalableInsertOp::create(
            rewriter, loc, sourceSubVector, currentResultScalableVector,
            resIdx.back());
      }

      // 3. Update the source and result scalable vectors if needed.
      if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
          currentResultScalableVector != result) {
        // Finished row of result. Insert complete scalable vector into result
        // (n-D) vector.
        result = vector::InsertOp::create(rewriter, loc,
                                          currentResultScalableVector, result,
                                          llvm::ArrayRef(resIdx).drop_back());
        currentResultScalableVector = {};
      }
      if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
        // Finished row of source.
        currentSourceScalableVector = {};
      }

      // 4. Increment the insert/extract indices, stepping by minExtractionSize
      // for the trailing dimensions.
      inplaceAdd(minExtractionSize, sourceVectorType.getShape(), srcIdx);
      inplaceAdd(minExtractionSize, resultVectorType.getShape(), resIdx);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

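  /// Returns true iff the trailing dimension of `type` is scalable and no
  /// other dimension is scalable. For example, this is true for
  /// vector<2x[4]xi32>, and false for both vector<[2]x4xi32> and
  /// vector<2x4xi32>.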
  static bool isTrailingDimScalable(VectorType type) {
    return type.getRank() >= 1 && type.getScalableDims().back() &&
           !llvm::is_contained(type.getScalableDims().drop_back(), true);
  }
};

} // namespace

void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ShapeCastOpRewritePattern, ScalableShapeCastOpRewritePattern>(
      patterns.getContext(), benefit);
}
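
// For illustration, a typical way to apply these patterns from inside a pass
// (a sketch, assuming a greedy rewrite driver suits the surrounding pass):
//
//   RewritePatternSet patterns(&getContext());
//   vector::populateVectorShapeCastLoweringPatterns(patterns);
//   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();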