//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.shape_cast' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include <numeric>

#define DEBUG_TYPE "vector-shape-cast-lowering"

using namespace mlir;

/// Perform the inplace update
///   rhs <- lhs + rhs
///
/// where `rhs` is a number expressed in mixed base `base` with the most
/// significant dimensions on the left. For example, if `rhs` is {a,b,c} and
/// `base` is {5,3,2} then `rhs` has value a*3*2 + b*2 + c.
///
/// Some examples where `base` is {5,3,2}:
///   rhs = {0,0,0}, lhs = 1  --> rhs = {0,0,1}
///   rhs = {0,0,1}, lhs = 1  --> rhs = {0,1,0}
///   rhs = {0,0,0}, lhs = 25 --> rhs = {4,0,1}
///
/// Invalid:
///   rhs = {0,0,2}, lhs = 1 : rhs not in base {5,3,2}
///
/// Overflows not handled correctly:
///   rhs = {4,2,1}, lhs = 2 --> rhs = {0,0,0} (not {0,0,1})
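///
/// Illustrative trace: with `base` = {2,3} and `rhs` starting at {0,0},
/// repeated calls of inplaceAdd(1, base, rhs) visit {0,1}, {0,2}, {1,0},
/// {1,1}, {1,2}. The function therefore acts as a mixed-radix odometer,
/// which the patterns below use to step extract/insert indices in row-major
/// order.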
static void inplaceAdd(int64_t lhs, ArrayRef<int64_t> base,
                       MutableArrayRef<int64_t> rhs) {

  // For dimensions in [rhs.size() - 1, ..., 2, 1, 0]:
  for (int dim : llvm::reverse(llvm::seq<int>(0, rhs.size()))) {
    int64_t dimBase = base[dim];
    assert(rhs[dim] < dimBase && "rhs not in base");

    int64_t incremented = rhs[dim] + lhs;

    // If the incremented value exceeds the dimension base, we must spill to
    // the next most significant dimension and repeat (we might need to spill
    // to more significant dimensions multiple times).
    lhs = incremented / dimBase;
    rhs[dim] = incremented % dimBase;
    if (lhs == 0)
      break;
  }
}

namespace {

/// shape_cast is converted to a sequence of extract, extract_strided_slice,
/// insert_strided_slice, and insert operations. The running example will be:
///
///    %0 = vector.shape_cast %arg0 :
///                  vector<2x2x3x4x7x11xi8> to vector<8x6x7x11xi8>
///
/// In this example the source and result shapes share a common suffix of 7x11.
/// This means we can always decompose the shape_cast into extract, insert, and
/// their strided equivalents, on vectors with shape suffix 7x11.
///
/// The greatest common divisor (gcd) of the dimensions immediately preceding
/// the common suffix is gcd(4,6) = 2. The algorithm implemented here will
/// operate on vectors with shapes that are `multiples` of (what we define as)
/// the 'atomic shape', 2x7x11. The atomic shape is `gcd` x `common-suffix`.
///
///    vector<2x2x3x4x7x11xi8>  to  vector<8x6x7x11xi8>
///                  ^  ^^^^              ^  ^^^^
///                  |  |                 |  |
///                  |  +----------------------|-+-----> common suffix of 7x11
///                  +------------------------+--------> gcd(4,6) is 2
///
///    atomic shape = gcd x common-suffix = 2x7x11
///
/// The decomposition implemented in this pattern consists of a sequence of
/// repeated steps:
///
///  (1) Extract vectors from the suffix of the source.
///      In our example this is 2x2x3x4x7x11 -> 4x7x11.
///
///  (2) Do extract_strided_slice down to the atomic shape.
///      In our example this is 4x7x11 -> 2x7x11.
///
///  (3) Do insert_strided_slice to the suffix of the result.
///      In our example this is 2x7x11 -> 6x7x11.
///
///  (4) Insert these vectors into the result vector.
///      In our example this is 6x7x11 -> 8x6x7x11.
///
/// These steps occur with different periods. In this example
///  (1) occurs 12 times,
///  (2) and (3) occur 24 times, and
///  (4) occurs 8 times.
///
/// Two special cases are handled independently in this pattern:
///  (i)  a shape_cast that just does leading 1 insertion/removal, and
///  (ii) a shape_cast where the gcd is 1.
///
/// These two cases can have more compact IR generated by not using the
/// generic algorithm described above.
///
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {

  // Case (i) of the description above.
  // Assumes source and result shapes are identical up to some leading ones.
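  //
  // For example (illustrative), vector<1x1x4x5xi8> -> vector<4x5xi8> becomes
  // a single vector.extract at [0, 0] from the source, followed by a
  // vector.insert (with an empty position) into a ub.poison of the result
  // type.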
  static LogicalResult leadingOnesLowering(vector::ShapeCastOp shapeCast,
                                           PatternRewriter &rewriter) {

    const Location loc = shapeCast.getLoc();
    const VectorType sourceType = shapeCast.getSourceVectorType();
    const VectorType resultType = shapeCast.getResultVectorType();

    const int64_t sourceRank = sourceType.getRank();
    const int64_t resultRank = resultType.getRank();
    const int64_t delta = sourceRank - resultRank;
    const int64_t sourceLeading = delta > 0 ? delta : 0;
    const int64_t resultLeading = delta > 0 ? 0 : -delta;

    const Value source = shapeCast.getSource();
    const Value poison = rewriter.create<ub::PoisonOp>(loc, resultType);
    const Value extracted = rewriter.create<vector::ExtractOp>(
        loc, source, SmallVector<int64_t>(sourceLeading, 0));
    const Value result = rewriter.create<vector::InsertOp>(
        loc, extracted, poison, SmallVector<int64_t>(resultLeading, 0));

    rewriter.replaceOp(shapeCast, result);
    return success();
  }

  // Case (ii) of the description above.
  // Assumes a shape_cast where the suffix shape of the source starting at
  // `sourceDim` and the suffix shape of the result starting at `resultDim`
  // are identical.
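  //
  // For example (illustrative), vector<2x3x5xi8> -> vector<3x2x5xi8> with
  // sourceDim = resultDim = 2 is lowered to 6 vector.extract / vector.insert
  // pairs that move vector<5xi8> slices, with the indices advanced by
  // inplaceAdd.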
  static LogicalResult noStridedSliceLowering(vector::ShapeCastOp shapeCast,
                                              int64_t sourceDim,
                                              int64_t resultDim,
                                              PatternRewriter &rewriter) {

    const Location loc = shapeCast.getLoc();

    const Value source = shapeCast.getSource();
    const ArrayRef<int64_t> sourceShape =
        shapeCast.getSourceVectorType().getShape();

    const VectorType resultType = shapeCast.getResultVectorType();
    const ArrayRef<int64_t> resultShape = resultType.getShape();

    const int64_t nSlices =
        std::accumulate(sourceShape.begin(), sourceShape.begin() + sourceDim,
                        1, std::multiplies<int64_t>());

    SmallVector<int64_t> extractIndex(sourceDim, 0);
    SmallVector<int64_t> insertIndex(resultDim, 0);
    Value result = rewriter.create<ub::PoisonOp>(loc, resultType);

    for (int i = 0; i < nSlices; ++i) {
      Value extracted =
          rewriter.create<vector::ExtractOp>(loc, source, extractIndex);

      result = rewriter.create<vector::InsertOp>(loc, extracted, result,
                                                 insertIndex);

      inplaceAdd(1, sourceShape.take_front(sourceDim), extractIndex);
      inplaceAdd(1, resultShape.take_front(resultDim), insertIndex);
    }
    rewriter.replaceOp(shapeCast, result);
    return success();
  }

public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType sourceType = op.getSourceVectorType();
    VectorType resultType = op.getResultVectorType();

    if (sourceType.isScalable() || resultType.isScalable())
      return rewriter.notifyMatchFailure(
          op,
          "shape_cast where vectors are scalable not handled by this pattern");

    const ArrayRef<int64_t> sourceShape = sourceType.getShape();
    const ArrayRef<int64_t> resultShape = resultType.getShape();
    const int64_t sourceRank = sourceType.getRank();
    const int64_t resultRank = resultType.getRank();
    const int64_t numElms = sourceType.getNumElements();
    const Value source = op.getSource();

    // Find the first dimension (working backwards from the end) at which the
    // source and result dimension sizes differ. Using the running example:
    //
    //   dimensions:  [0 1 2 3 4  5]        [0 1 2  3]
    //   shapes:      (2,2,3,4,7,11)   ->   (8,6,7,11)
    //                       ^                  ^
    //                       |                  |
    //     sourceSuffixStartDim is 3            |
    //                                          |
    //                       resultSuffixStartDim is 1
    int64_t sourceSuffixStartDim = sourceRank - 1;
    int64_t resultSuffixStartDim = resultRank - 1;
    while (sourceSuffixStartDim >= 0 && resultSuffixStartDim >= 0 &&
           (sourceType.getDimSize(sourceSuffixStartDim) ==
            resultType.getDimSize(resultSuffixStartDim))) {
      --sourceSuffixStartDim;
      --resultSuffixStartDim;
    }

    // This is case (i), where there are just some leading ones to contend
    // with in the source or result. It can be handled with a single
    // extract/insert pair.
    if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0)
      return leadingOnesLowering(op, rewriter);

    const int64_t sourceSuffixStartDimSize =
        sourceType.getDimSize(sourceSuffixStartDim);
    const int64_t resultSuffixStartDimSize =
        resultType.getDimSize(resultSuffixStartDim);
    const int64_t greatestCommonDivisor =
        std::gcd(sourceSuffixStartDimSize, resultSuffixStartDimSize);
    const int64_t stridedSliceRank = sourceRank - sourceSuffixStartDim;
    const size_t extractPeriod =
        sourceSuffixStartDimSize / greatestCommonDivisor;
    const size_t insertPeriod =
        resultSuffixStartDimSize / greatestCommonDivisor;

    SmallVector<int64_t> atomicShape(sourceShape.begin() + sourceSuffixStartDim,
                                     sourceShape.end());
    atomicShape[0] = greatestCommonDivisor;

    const int64_t numAtomicElms = std::accumulate(
        atomicShape.begin(), atomicShape.end(), 1, std::multiplies<int64_t>());
    const size_t nAtomicSlices = numElms / numAtomicElms;

    // This is case (ii), where the strided dimension size (the gcd) is 1.
    // More compact IR is generated in this case if we just extract and insert
    // the elements directly. In other words, we don't use
    // extract_strided_slice and insert_strided_slice.
    if (greatestCommonDivisor == 1)
      return noStridedSliceLowering(op, sourceSuffixStartDim + 1,
                                    resultSuffixStartDim + 1, rewriter);

    // The insert_strided_slice result's type.
    const ArrayRef<int64_t> insertStridedShape =
        resultShape.drop_front(resultSuffixStartDim);
    const VectorType insertStridedType =
        VectorType::get(insertStridedShape, resultType.getElementType());

    SmallVector<int64_t> extractIndex(sourceSuffixStartDim, 0);
    SmallVector<int64_t> insertIndex(resultSuffixStartDim, 0);
    SmallVector<int64_t> extractOffsets(stridedSliceRank, 0);
    SmallVector<int64_t> insertOffsets(stridedSliceRank, 0);
    const SmallVector<int64_t> sizes(stridedSliceRank, 1);

    Value extracted = {};
    Value extractedStrided = {};
    Value insertedSlice = {};
    Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
    const Value partResult =
        rewriter.create<ub::PoisonOp>(loc, insertStridedType);

    for (size_t i = 0; i < nAtomicSlices; ++i) {

      const size_t extractStridedPhase = i % extractPeriod;
      const size_t insertStridedPhase = i % insertPeriod;

      // vector.extract
      if (extractStridedPhase == 0) {
        extracted =
            rewriter.create<vector::ExtractOp>(loc, source, extractIndex);
        inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim),
                   extractIndex);
      }

      // vector.extract_strided_slice
      extractOffsets[0] = extractStridedPhase * greatestCommonDivisor;
      extractedStrided = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, extracted, extractOffsets, atomicShape, sizes);

      // vector.insert_strided_slice
      if (insertStridedPhase == 0) {
        insertedSlice = partResult;
      }
      insertOffsets[0] = insertStridedPhase * greatestCommonDivisor;
      insertedSlice = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extractedStrided, insertedSlice, insertOffsets, sizes);

      // vector.insert
      if (insertStridedPhase + 1 == insertPeriod) {
        result = rewriter.create<vector::InsertOp>(loc, insertedSlice, result,
                                                   insertIndex);
        inplaceAdd(1, resultType.getShape().take_front(resultSuffixStartDim),
                   insertIndex);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// A shape_cast lowering for scalable vectors with a single trailing scalable
/// dimension. This is similar to the general shape_cast lowering but makes use
/// of vector.scalable.insert and vector.scalable.extract to move elements a
/// subvector at a time.
///
/// E.g.:
/// ```
/// // Flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Flatten scalable vector
/// %c = ub.poison : vector<[8]xi32>
/// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
/// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
/// ```
/// or:
/// ```
/// // Un-flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Un-flatten scalable vector
/// %c = ub.poison : vector<2x1x[4]xi32>
/// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
/// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
/// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// ```
class ScalableShapeCastOpRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    auto srcRank = sourceVectorType.getRank();
    auto resRank = resultVectorType.getRank();

    // This pattern can only lower shape_casts where both the source and
    // result types have a single trailing scalable dimension. This is because
    // there is no legal representation of other scalable types in LLVM (and
    // likely won't be soon). There are also (currently) no operations that
    // can index or extract from >= 2-D scalable vectors or scalable vectors
    // of fixed vectors.
    if (!isTrailingDimScalable(sourceVectorType) ||
        !isTrailingDimScalable(resultVectorType)) {
      return rewriter.notifyMatchFailure(
          op, "trailing dims are not scalable, not handled by this pattern");
    }

    // The sizes of the trailing dimension of the source and result vectors,
    // the size of the subvector to move, and the number of elements in the
    // vectors. These are "min" sizes as they are the sizes when vscale == 1.
    auto minSourceTrailingSize = sourceVectorType.getShape().back();
    auto minResultTrailingSize = resultVectorType.getShape().back();
    auto minExtractionSize =
        std::min(minSourceTrailingSize, minResultTrailingSize);
    int64_t minNumElts = 1;
    for (auto size : sourceVectorType.getShape())
      minNumElts *= size;

    // The subvector type to move from the source to the result. Note that this
    // is a scalable vector. This rewrite will generate code in terms of the
    // "min" size (the vscale == 1 case), that scales to any vscale.
    auto extractionVectorType = VectorType::get(
        {minExtractionSize}, sourceVectorType.getElementType(), {true});

    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
    SmallVector<int64_t> srcIdx(srcRank, 0);
    SmallVector<int64_t> resIdx(resRank, 0);

    // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
    // once D150000 lands.
    Value currentResultScalableVector;
    Value currentSourceScalableVector;
    for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
      // 1. Extract a scalable subvector from the source vector.
      if (!currentSourceScalableVector) {
        if (srcRank != 1) {
          currentSourceScalableVector = rewriter.create<vector::ExtractOp>(
              loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back());
        } else {
          currentSourceScalableVector = op.getSource();
        }
      }
      Value sourceSubVector = currentSourceScalableVector;
      if (minExtractionSize < minSourceTrailingSize) {
        sourceSubVector = rewriter.create<vector::ScalableExtractOp>(
            loc, extractionVectorType, sourceSubVector, srcIdx.back());
      }

      // 2. Insert the scalable subvector into the result vector.
      if (!currentResultScalableVector) {
        if (minExtractionSize == minResultTrailingSize) {
          currentResultScalableVector = sourceSubVector;
        } else if (resRank != 1) {
          currentResultScalableVector = rewriter.create<vector::ExtractOp>(
              loc, result, llvm::ArrayRef(resIdx).drop_back());
        } else {
          currentResultScalableVector = result;
        }
      }
      if (minExtractionSize < minResultTrailingSize) {
        currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>(
            loc, sourceSubVector, currentResultScalableVector, resIdx.back());
      }

      // 3. Update the source and result scalable vectors if needed.
      if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
          currentResultScalableVector != result) {
        // Finished row of result. Insert complete scalable vector into result
        // (n-D) vector.
        result = rewriter.create<vector::InsertOp>(
            loc, currentResultScalableVector, result,
            llvm::ArrayRef(resIdx).drop_back());
        currentResultScalableVector = {};
      }
      if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
        // Finished row of source.
        currentSourceScalableVector = {};
      }

      // 4. Increment the insert/extract indices, stepping by minExtractionSize
      // for the trailing dimensions.
      inplaceAdd(minExtractionSize, sourceVectorType.getShape(), srcIdx);
      inplaceAdd(minExtractionSize, resultVectorType.getShape(), resIdx);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
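
  // For example (illustrative): isTrailingDimScalable returns true for
  // vector<2x[4]xi32>, and false for both vector<[2]x4xi32> and
  // vector<[2]x[4]xi32>.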
  static bool isTrailingDimScalable(VectorType type) {
    return type.getRank() >= 1 && type.getScalableDims().back() &&
           !llvm::is_contained(type.getScalableDims().drop_back(), true);
  }
};

} // namespace

void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ShapeCastOpRewritePattern, ScalableShapeCastOpRewritePattern>(
      patterns.getContext(), benefit);
}
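
// Typical usage (illustrative sketch only): a lowering pass can collect these
// patterns and apply them greedily, e.g.
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorShapeCastLoweringPatterns(patterns);
//   (void)applyPatternsGreedily(op, std::move(patterns));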