MLIR  19.0.0git
LowerVectorShapeCast.cpp
Go to the documentation of this file.
1 //===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.shape_cast' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
28 #include "mlir/IR/BuiltinTypes.h"
30 #include "mlir/IR/Location.h"
31 #include "mlir/IR/Matchers.h"
32 #include "mlir/IR/PatternMatch.h"
33 #include "mlir/IR/TypeUtilities.h"
36 
37 #define DEBUG_TYPE "vector-shape-cast-lowering"
38 
39 using namespace mlir;
40 using namespace mlir::vector;
41 
42 namespace {
43 /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
44 /// vectors progressively on the way to target llvm.matrix intrinsics.
45 /// This iterates over the most major dimension of the 2-D vector and performs
46 /// rewrites into:
47 /// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
48 class ShapeCastOp2DDownCastRewritePattern
49  : public OpRewritePattern<vector::ShapeCastOp> {
50 public:
52 
53  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
54  PatternRewriter &rewriter) const override {
55  auto sourceVectorType = op.getSourceVectorType();
56  auto resultVectorType = op.getResultVectorType();
57 
58  if (sourceVectorType.isScalable() || resultVectorType.isScalable())
59  return failure();
60 
61  if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
62  return failure();
63 
64  auto loc = op.getLoc();
65  Value desc = rewriter.create<arith::ConstantOp>(
66  loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
67  unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
68  for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
69  Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
70  desc = rewriter.create<vector::InsertStridedSliceOp>(
71  loc, vec, desc,
72  /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
73  }
74  rewriter.replaceOp(op, desc);
75  return success();
76  }
77 };
78 
79 /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
80 /// vectors progressively.
81 /// This iterates over the most major dimension of the 2-D vector and performs
82 /// rewrites into:
83 /// vector.extract_strided_slice from 1-D + vector.insert into 2-D
84 /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
85 class ShapeCastOp2DUpCastRewritePattern
86  : public OpRewritePattern<vector::ShapeCastOp> {
87 public:
89 
90  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
91  PatternRewriter &rewriter) const override {
92  auto sourceVectorType = op.getSourceVectorType();
93  auto resultVectorType = op.getResultVectorType();
94 
95  if (sourceVectorType.isScalable() || resultVectorType.isScalable())
96  return failure();
97 
98  if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
99  return failure();
100 
101  auto loc = op.getLoc();
102  Value desc = rewriter.create<arith::ConstantOp>(
103  loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
104  unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
105  for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
106  Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
107  loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
108  /*sizes=*/mostMinorVectorSize,
109  /*strides=*/1);
110  desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
111  }
112  rewriter.replaceOp(op, desc);
113  return success();
114  }
115 };
116 
117 static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp,
118  int dimIdx, int initialStep = 1) {
119  int step = initialStep;
120  for (int d = dimIdx; d >= 0; d--) {
121  idx[d] += step;
122  if (idx[d] >= tp.getDimSize(d)) {
123  idx[d] = 0;
124  step = 1;
125  } else {
126  break;
127  }
128  }
129 }
130 
131 // We typically should not lower general shape cast operations into data
132 // movement instructions, since the assumption is that these casts are
133 // optimized away during progressive lowering. For completeness, however,
134 // we fall back to a reference implementation that moves all elements
135 // into the right place if we get here.
136 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
137 public:
139 
140  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
141  PatternRewriter &rewriter) const override {
142  Location loc = op.getLoc();
143  auto sourceVectorType = op.getSourceVectorType();
144  auto resultVectorType = op.getResultVectorType();
145 
146  if (sourceVectorType.isScalable() || resultVectorType.isScalable())
147  return failure();
148 
149  // Special case 2D / 1D lowerings with better implementations.
150  // TODO: make is ND / 1D to allow generic ND -> 1D -> MD.
151  int64_t srcRank = sourceVectorType.getRank();
152  int64_t resRank = resultVectorType.getRank();
153  if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
154  return failure();
155 
156  // Generic ShapeCast lowering path goes all the way down to unrolled scalar
157  // extract/insert chains.
158  // TODO: consider evolving the semantics to only allow 1D source or dest and
159  // drop this potentially very expensive lowering.
160  // Compute number of elements involved in the reshape.
161  int64_t numElts = 1;
162  for (int64_t r = 0; r < srcRank; r++)
163  numElts *= sourceVectorType.getDimSize(r);
164  // Replace with data movement operations:
165  // x[0,0,0] = y[0,0]
166  // x[0,0,1] = y[0,1]
167  // x[0,1,0] = y[0,2]
168  // etc., incrementing the two index vectors "row-major"
169  // within the source and result shape.
170  SmallVector<int64_t> srcIdx(srcRank);
171  SmallVector<int64_t> resIdx(resRank);
172  Value result = rewriter.create<arith::ConstantOp>(
173  loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
174  for (int64_t i = 0; i < numElts; i++) {
175  if (i != 0) {
176  incIdx(srcIdx, sourceVectorType, srcRank - 1);
177  incIdx(resIdx, resultVectorType, resRank - 1);
178  }
179 
180  Value extract;
181  if (srcRank == 0) {
182  // 0-D vector special case
183  assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
184  extract = rewriter.create<vector::ExtractElementOp>(
185  loc, op.getSourceVectorType().getElementType(), op.getSource());
186  } else {
187  extract =
188  rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
189  }
190 
191  if (resRank == 0) {
192  // 0-D vector special case
193  assert(resIdx.empty() && "Unexpected indices for 0-D vector");
194  result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
195  } else {
196  result =
197  rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
198  }
199  }
200  rewriter.replaceOp(op, result);
201  return success();
202  }
203 };
204 
205 /// A shape_cast lowering for scalable vectors with a single trailing scalable
206 /// dimension. This is similar to the general shape_cast lowering but makes use
207 /// of vector.scalable.insert and vector.scalable.extract to move elements a
208 /// subvector at a time.
209 ///
210 /// E.g.:
211 /// ```
212 /// // Flatten scalable vector
213 /// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
214 /// ```
215 /// is rewritten to:
216 /// ```
217 /// // Flatten scalable vector
218 /// %c = arith.constant dense<0> : vector<[8]xi32>
219 /// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
220 /// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
221 /// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
222 /// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
223 /// ```
224 /// or:
225 /// ```
226 /// // Un-flatten scalable vector
227 /// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
228 /// ```
229 /// is rewritten to:
230 /// ```
231 /// // Un-flatten scalable vector
232 /// %c = arith.constant dense<0> : vector<2x1x[4]xi32>
233 /// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
234 /// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
235 /// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
236 /// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
237 /// ```
238 class ScalableShapeCastOpRewritePattern
239  : public OpRewritePattern<vector::ShapeCastOp> {
240 public:
242 
243  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
244  PatternRewriter &rewriter) const override {
245 
246  Location loc = op.getLoc();
247  auto sourceVectorType = op.getSourceVectorType();
248  auto resultVectorType = op.getResultVectorType();
249  auto srcRank = sourceVectorType.getRank();
250  auto resRank = resultVectorType.getRank();
251 
252  // This can only lower shape_casts where both the source and result types
253  // have a single trailing scalable dimension. This is because there are no
254  // legal representation of other scalable types in LLVM (and likely won't be
255  // soon). There are also (currently) no operations that can index or extract
256  // from >= 2D scalable vectors or scalable vectors of fixed vectors.
257  if (!isTrailingDimScalable(sourceVectorType) ||
258  !isTrailingDimScalable(resultVectorType)) {
259  return failure();
260  }
261 
262  // The sizes of the trailing dimension of the source and result vectors, the
263  // size of subvector to move, and the number of elements in the vectors.
264  // These are "min" sizes as they are the size when vscale == 1.
265  auto minSourceTrailingSize = sourceVectorType.getShape().back();
266  auto minResultTrailingSize = resultVectorType.getShape().back();
267  auto minExtractionSize =
268  std::min(minSourceTrailingSize, minResultTrailingSize);
269  int64_t minNumElts = 1;
270  for (auto size : sourceVectorType.getShape())
271  minNumElts *= size;
272 
273  // The subvector type to move from the source to the result. Note that this
274  // is a scalable vector. This rewrite will generate code in terms of the
275  // "min" size (vscale == 1 case), that scales to any vscale.
276  auto extractionVectorType = VectorType::get(
277  {minExtractionSize}, sourceVectorType.getElementType(), {true});
278 
279  Value result = rewriter.create<arith::ConstantOp>(
280  loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
281 
282  SmallVector<int64_t> srcIdx(srcRank);
283  SmallVector<int64_t> resIdx(resRank);
284 
285  // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
286  // once D150000 lands.
287  Value currentResultScalableVector;
288  Value currentSourceScalableVector;
289  for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
290  // 1. Extract a scalable subvector from the source vector.
291  if (!currentSourceScalableVector) {
292  if (srcRank != 1) {
293  currentSourceScalableVector = rewriter.create<vector::ExtractOp>(
294  loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back());
295  } else {
296  currentSourceScalableVector = op.getSource();
297  }
298  }
299  Value sourceSubVector = currentSourceScalableVector;
300  if (minExtractionSize < minSourceTrailingSize) {
301  sourceSubVector = rewriter.create<vector::ScalableExtractOp>(
302  loc, extractionVectorType, sourceSubVector, srcIdx.back());
303  }
304 
305  // 2. Insert the scalable subvector into the result vector.
306  if (!currentResultScalableVector) {
307  if (minExtractionSize == minResultTrailingSize) {
308  currentResultScalableVector = sourceSubVector;
309  } else if (resRank != 1) {
310  currentResultScalableVector = rewriter.create<vector::ExtractOp>(
311  loc, result, llvm::ArrayRef(resIdx).drop_back());
312  } else {
313  currentResultScalableVector = result;
314  }
315  }
316  if (minExtractionSize < minResultTrailingSize) {
317  currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>(
318  loc, sourceSubVector, currentResultScalableVector, resIdx.back());
319  }
320 
321  // 3. Update the source and result scalable vectors if needed.
322  if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
323  currentResultScalableVector != result) {
324  // Finished row of result. Insert complete scalable vector into result
325  // (n-D) vector.
326  result = rewriter.create<vector::InsertOp>(
327  loc, currentResultScalableVector, result,
328  llvm::ArrayRef(resIdx).drop_back());
329  currentResultScalableVector = {};
330  }
331  if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
332  // Finished row of source.
333  currentSourceScalableVector = {};
334  }
335 
336  // 4. Increment the insert/extract indices, stepping by minExtractionSize
337  // for the trailing dimensions.
338  incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
339  incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
340  }
341 
342  rewriter.replaceOp(op, result);
343  return success();
344  }
345 
346  static bool isTrailingDimScalable(VectorType type) {
347  return type.getRank() >= 1 && type.getScalableDims().back() &&
348  !llvm::is_contained(type.getScalableDims().drop_back(), true);
349  }
350 };
351 
352 } // namespace
353 
355  RewritePatternSet &patterns, PatternBenefit benefit) {
356  patterns.add<ShapeCastOp2DDownCastRewritePattern,
357  ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
358  ScalableShapeCastOpRewritePattern>(patterns.getContext(),
359  benefit);
360 }
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362