//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.shape_cast' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-shape-cast-lowering"

using namespace mlir;
using namespace mlir::vector;

/// Increments n-D `indices` by `step` starting from the innermost dimension.
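/// For example (an illustrative note, derived from the implementation below):
/// for a vector<3x3xf32>, `indices` = [1, 2] incremented by step 1 becomes
/// [2, 0]; the innermost index wraps around and carries into the next outer
/// dimension with a step of 1.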
static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
                   int step = 1) {
  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
    assert(indices[dim] < vecType.getDimSize(dim) &&
           "Indices are out of bounds");
    indices[dim] += step;
    if (indices[dim] < vecType.getDimSize(dim))
      break;

    indices[dim] = 0;
    step = 1;
  }
}

namespace {
/// ShapeCastOp n-D -> 1-D downcast serves the purpose of flattening n-D to
/// 1-D vectors progressively. It iterates over the n-1 major dimensions of
/// the n-D vector and rewrites each step into:
///   vector.extract from n-D + vector.insert_strided_slice offset into 1-D
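///
/// For example (an illustrative sketch of the generated IR):
/// ```
/// %0 = vector.shape_cast %arg0 : vector<2x3xf32> to vector<6xf32>
/// ```
/// is rewritten to:
/// ```
/// %p = ub.poison : vector<6xf32>
/// %e0 = vector.extract %arg0[0] : vector<3xf32> from vector<2x3xf32>
/// %i0 = vector.insert_strided_slice %e0, %p {offsets = [0], strides = [1]}
///     : vector<3xf32> into vector<6xf32>
/// %e1 = vector.extract %arg0[1] : vector<3xf32> from vector<2x3xf32>
/// %1 = vector.insert_strided_slice %e1, %i0 {offsets = [3], strides = [1]}
///     : vector<3xf32> into vector<6xf32>
/// ```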
class ShapeCastOpNDDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if (srcRank < 2 || resRank != 1)
      return failure();

    // Compute the number of 1-D vector elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t dim = 0; dim < srcRank - 1; ++dim)
      numElts *= sourceVectorType.getDimSize(dim);

    auto loc = op.getLoc();
    SmallVector<int64_t> srcIdx(srcRank - 1, 0);
    SmallVector<int64_t> resIdx(resRank, 0);
    int64_t extractSize = sourceVectorType.getShape().back();
    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);

    // Compute the indices of each 1-D vector element of the source extraction
    // and destination slice insertion and generate such instructions.
    for (int64_t i = 0; i < numElts; ++i) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, /*step=*/1);
        incIdx(resIdx, resultVectorType, /*step=*/extractSize);
      }

      Value extract =
          rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extract, result,
          /*offsets=*/resIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

/// ShapeCastOp 1-D -> n-D upcast serves the purpose of unflattening n-D from
/// 1-D vectors progressively. It iterates over the n-1 major dimensions of
/// the n-D vector and rewrites each step into:
///   vector.extract_strided_slice from 1-D + vector.insert into n-D
/// Note that 1-D extract_strided_slice ops are lowered to efficient
/// vector.shuffle ops.
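///
/// For example (an illustrative sketch of the generated IR):
/// ```
/// %0 = vector.shape_cast %arg0 : vector<6xf32> to vector<2x3xf32>
/// ```
/// is rewritten to:
/// ```
/// %p = ub.poison : vector<2x3xf32>
/// %e0 = vector.extract_strided_slice %arg0
///     {offsets = [0], sizes = [3], strides = [1]}
///     : vector<6xf32> to vector<3xf32>
/// %i0 = vector.insert %e0, %p [0] : vector<3xf32> into vector<2x3xf32>
/// %e1 = vector.extract_strided_slice %arg0
///     {offsets = [3], sizes = [3], strides = [1]}
///     : vector<6xf32> to vector<3xf32>
/// %1 = vector.insert %e1, %i0 [1] : vector<3xf32> into vector<2x3xf32>
/// ```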
class ShapeCastOpNDUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if (srcRank != 1 || resRank < 2)
      return failure();

    // Compute the number of 1-D vector elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t dim = 0; dim < resRank - 1; ++dim)
      numElts *= resultVectorType.getDimSize(dim);

    // Compute the indices of each 1-D vector element of the source slice
    // extraction and destination insertion and generate such instructions.
    auto loc = op.getLoc();
    SmallVector<int64_t> srcIdx(srcRank, 0);
    SmallVector<int64_t> resIdx(resRank - 1, 0);
    int64_t extractSize = resultVectorType.getShape().back();
    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
    for (int64_t i = 0; i < numElts; ++i) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
        incIdx(resIdx, resultVectorType, /*step=*/1);
      }

      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.getSource(), /*offsets=*/srcIdx, /*sizes=*/extractSize,
          /*strides=*/1);
      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
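//
// For example (an illustrative sketch of the generated IR):
//   %0 = vector.shape_cast %arg0 : vector<2x3xf32> to vector<3x2xf32>
// is rewritten to an unrolled chain of scalar moves:
//   %p  = ub.poison : vector<3x2xf32>
//   %e0 = vector.extract %arg0[0, 0] : f32 from vector<2x3xf32>
//   %i0 = vector.insert %e0, %p [0, 0] : f32 into vector<3x2xf32>
//   %e1 = vector.extract %arg0[0, 1] : f32 from vector<2x3xf32>
//   %i1 = vector.insert %e1, %i0 [0, 1] : f32 into vector<3x2xf32>
//   %e2 = vector.extract %arg0[0, 2] : f32 from vector<2x3xf32>
//   %i2 = vector.insert %e2, %i1 [1, 0] : f32 into vector<3x2xf32>
//   ... and so on for the remaining three elements, row-major in both shapes.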
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    // Special case: n-D <-> 1-D lowerings are handled by the dedicated
    // patterns above, which have better implementations.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
      return failure();

    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
    // extract/insert chains.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    //   x[0,0,0] = y[0,0]
    //   x[0,0,1] = y[0,1]
    //   x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
    SmallVector<int64_t> srcIdx(srcRank, 0);
    SmallVector<int64_t> resIdx(resRank, 0);
    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType);
        incIdx(resIdx, resultVectorType);
      }

      Value extract =
          rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// A shape_cast lowering for scalable vectors with a single trailing scalable
/// dimension. This is similar to the general shape_cast lowering but makes use
/// of vector.scalable.insert and vector.scalable.extract to move elements a
/// subvector at a time.
///
/// E.g.:
/// ```
/// // Flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Flatten scalable vector
/// %c = ub.poison : vector<[8]xi32>
/// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
/// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
/// ```
/// or:
/// ```
/// // Un-flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Un-flatten scalable vector
/// %c = ub.poison : vector<2x1x[4]xi32>
/// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
/// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
/// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// ```
class ScalableShapeCastOpRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    auto srcRank = sourceVectorType.getRank();
    auto resRank = resultVectorType.getRank();

    // This can only lower shape_casts where both the source and result types
    // have a single trailing scalable dimension. This is because there is no
    // legal representation of other scalable types in LLVM (and likely won't
    // be soon). There are also (currently) no operations that can index or
    // extract from >= 2-D scalable vectors or scalable vectors of fixed
    // vectors.
    if (!isTrailingDimScalable(sourceVectorType) ||
        !isTrailingDimScalable(resultVectorType)) {
      return failure();
    }

    // The sizes of the trailing dimensions of the source and result vectors,
    // the size of the subvector to move, and the number of elements in the
    // vectors. These are "min" sizes, i.e. the sizes when vscale == 1.
    auto minSourceTrailingSize = sourceVectorType.getShape().back();
    auto minResultTrailingSize = resultVectorType.getShape().back();
    auto minExtractionSize =
        std::min(minSourceTrailingSize, minResultTrailingSize);
    int64_t minNumElts = 1;
    for (auto size : sourceVectorType.getShape())
      minNumElts *= size;
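    // E.g. (an illustrative case): for vector<4x[4]xi32> -> vector<[8]xi32>,
    // minExtractionSize = min(4, 8) = 4 and minNumElts = 4 * 4 = 16, so the
    // loop below moves four [4]-element subvectors (counted at vscale == 1).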

    // The subvector type to move from the source to the result. Note that this
    // is a scalable vector. This rewrite will generate code in terms of the
    // "min" size (vscale == 1 case), that scales to any vscale.
    auto extractionVectorType = VectorType::get(
        {minExtractionSize}, sourceVectorType.getElementType(), {true});

    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
    SmallVector<int64_t> srcIdx(srcRank, 0);
    SmallVector<int64_t> resIdx(resRank, 0);

    // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
    // once D150000 lands.
    Value currentResultScalableVector;
    Value currentSourceScalableVector;
    for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
      // 1. Extract a scalable subvector from the source vector.
      if (!currentSourceScalableVector) {
        if (srcRank != 1) {
          currentSourceScalableVector = rewriter.create<vector::ExtractOp>(
              loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back());
        } else {
          currentSourceScalableVector = op.getSource();
        }
      }
      Value sourceSubVector = currentSourceScalableVector;
      if (minExtractionSize < minSourceTrailingSize) {
        sourceSubVector = rewriter.create<vector::ScalableExtractOp>(
            loc, extractionVectorType, sourceSubVector, srcIdx.back());
      }

      // 2. Insert the scalable subvector into the result vector.
      if (!currentResultScalableVector) {
        if (minExtractionSize == minResultTrailingSize) {
          currentResultScalableVector = sourceSubVector;
        } else if (resRank != 1) {
          currentResultScalableVector = rewriter.create<vector::ExtractOp>(
              loc, result, llvm::ArrayRef(resIdx).drop_back());
        } else {
          currentResultScalableVector = result;
        }
      }
      if (minExtractionSize < minResultTrailingSize) {
        currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>(
            loc, sourceSubVector, currentResultScalableVector, resIdx.back());
      }

      // 3. Update the source and result scalable vectors if needed.
      if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
          currentResultScalableVector != result) {
        // Finished row of result. Insert complete scalable vector into result
        // (n-D) vector.
        result = rewriter.create<vector::InsertOp>(
            loc, currentResultScalableVector, result,
            llvm::ArrayRef(resIdx).drop_back());
        currentResultScalableVector = {};
      }
      if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
        // Finished row of source.
        currentSourceScalableVector = {};
      }

      // 4. Increment the insert/extract indices, stepping by minExtractionSize
      // for the trailing dimension.
      incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
      incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

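  /// Returns true iff the only scalable dimension of `type` is its trailing
  /// dimension. For example (an illustrative note): vector<2x[4]xi32> -> true,
  /// vector<[2]x4xi32> -> false, vector<[2]x[4]xi32> -> false.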
  static bool isTrailingDimScalable(VectorType type) {
    return type.getRank() >= 1 && type.getScalableDims().back() &&
           !llvm::is_contained(type.getScalableDims().drop_back(), true);
  }
};

} // namespace

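// Example usage of the populate function below (a minimal sketch, not part of
// this file; assumes a pass body with a root operation `rootOp` and the greedy
// driver from mlir/Transforms/GreedyPatternRewriteDriver.h):
//
//   RewritePatternSet patterns(rootOp->getContext());
//   vector::populateVectorShapeCastLoweringPatterns(patterns);
//   if (failed(applyPatternsGreedily(rootOp, std::move(patterns))))
//     signalPassFailure();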
void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ShapeCastOpNDDownCastRewritePattern,
               ShapeCastOpNDUpCastRewritePattern, ShapeCastOpRewritePattern,
               ScalableShapeCastOpRewritePattern>(patterns.getContext(),
                                                  benefit);
}