LowerVectorShapeCast.cpp
//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.shape_cast' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-shape-cast-lowering"

using namespace mlir;
using namespace mlir::vector;

namespace {
/// ShapeCastOp 2-D -> 1-D downcast serves the purpose of flattening 2-D to
/// 1-D vectors progressively on the way to targeting llvm.matrix intrinsics.
/// It iterates over the most major dimension of the 2-D vector and rewrites
/// each row into:
///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
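///
/// For example, a cast such as (shapes and SSA names here are illustrative
/// only):
/// ```
/// %0 = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
/// ```
/// is rewritten, roughly, into:
/// ```
/// %c = arith.constant dense<0.000000e+00> : vector<8xf32>
/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<2x4xf32>
/// %1 = vector.insert_strided_slice %0, %c {offsets = [0], strides = [1]}
///      : vector<4xf32> into vector<8xf32>
/// %2 = vector.extract %arg0[1] : vector<4xf32> from vector<2x4xf32>
/// %3 = vector.insert_strided_slice %2, %1 {offsets = [4], strides = [1]}
///      : vector<4xf32> into vector<8xf32>
/// ```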
class ShapeCastOp2DDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
      desc = rewriter.create<vector::InsertStridedSliceOp>(
          loc, vec, desc,
          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// ShapeCastOp 1-D -> 2-D upcast serves the purpose of unflattening 2-D from
/// 1-D vectors progressively.
/// It iterates over the most major dimension of the 2-D vector and rewrites
/// each row into:
///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// Note that 1-D extract_strided_slice ops are lowered to efficient
/// vector.shuffle ops.
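///
/// For example, a cast such as (shapes and SSA names here are illustrative
/// only):
/// ```
/// %0 = vector.shape_cast %arg0 : vector<8xf32> to vector<2x4xf32>
/// ```
/// is rewritten, roughly, into:
/// ```
/// %c = arith.constant dense<0.000000e+00> : vector<2x4xf32>
/// %0 = vector.extract_strided_slice %arg0
///      {offsets = [0], sizes = [4], strides = [1]}
///      : vector<8xf32> to vector<4xf32>
/// %1 = vector.insert %0, %c [0] : vector<4xf32> into vector<2x4xf32>
/// %2 = vector.extract_strided_slice %arg0
///      {offsets = [4], sizes = [4], strides = [1]}
///      : vector<8xf32> to vector<4xf32>
/// %3 = vector.insert %2, %1 [1] : vector<4xf32> into vector<2x4xf32>
/// ```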
class ShapeCastOp2DUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
          /*sizes=*/mostMinorVectorSize,
          /*strides=*/1);
      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

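/// Increments the n-D index `idx` by `initialStep` within the shape of vector
/// type `tp`, in row-major order, starting at dimension `dimIdx` and carrying
/// overflow into the more major dimensions (the step is assumed to divide the
/// dimension sizes, as in all uses below). For example, with
/// tp = vector<2x3xf32> and idx = [0, 2], a step of 1 produces idx = [1, 0].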
static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp,
                   int dimIdx, int initialStep = 1) {
  int step = initialStep;
  for (int d = dimIdx; d >= 0; d--) {
    idx[d] += step;
    if (idx[d] >= tp.getDimSize(d)) {
      idx[d] = 0;
      step = 1;
    } else {
      break;
    }
  }
}

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
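//
// For instance, a cast such as (shapes and SSA names here are illustrative
// only)
//   %0 = vector.shape_cast %arg0 : vector<2x2xf32> to vector<1x4xf32>
// would be unrolled, roughly, into:
//   %c  = arith.constant dense<0.000000e+00> : vector<1x4xf32>
//   %e0 = vector.extract %arg0[0, 0] : f32 from vector<2x2xf32>
//   %i0 = vector.insert %e0, %c [0, 0] : f32 into vector<1x4xf32>
//   %e1 = vector.extract %arg0[0, 1] : f32 from vector<2x2xf32>
//   %i1 = vector.insert %e1, %i0 [0, 1] : f32 into vector<1x4xf32>
//   ... and so on for the remaining elements.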
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    // Special case 2-D / 1-D lowerings with better implementations.
    // TODO: make it N-D / 1-D to allow generic N-D -> 1-D -> M-D.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
      return failure();

    // The generic ShapeCast lowering path goes all the way down to unrolled
    // scalar extract/insert chains.
    // TODO: consider evolving the semantics to only allow 1-D source or dest
    // and drop this potentially very expensive lowering.
    // Compute the number of elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    //   x[0,0,0] = y[0,0]
    //   x[0,0,1] = y[0,1]
    //   x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
    SmallVector<int64_t> srcIdx(srcRank);
    SmallVector<int64_t> resIdx(resRank);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, srcRank - 1);
        incIdx(resIdx, resultVectorType, resRank - 1);
      }

      Value extract;
      if (srcRank == 0) {
        // 0-D vector special case.
        assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
        extract = rewriter.create<vector::ExtractElementOp>(
            loc, op.getSourceVectorType().getElementType(), op.getSource());
      } else {
        extract =
            rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
      }

      if (resRank == 0) {
        // 0-D vector special case.
        assert(resIdx.empty() && "Unexpected indices for 0-D vector");
        result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
      } else {
        result =
            rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// A shape_cast lowering for scalable vectors with a single trailing scalable
/// dimension. This is similar to the general shape_cast lowering but makes use
/// of vector.scalable.insert and vector.scalable.extract to move elements a
/// subvector at a time.
///
/// E.g.:
/// ```
/// // Flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Flatten scalable vector
/// %c = arith.constant dense<0> : vector<[8]xi32>
/// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
/// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
/// ```
/// or:
/// ```
/// // Un-flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Un-flatten scalable vector
/// %c = arith.constant dense<0> : vector<2x1x[4]xi32>
/// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
/// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
/// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// ```
class ScalableShapeCastOpRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    auto srcRank = sourceVectorType.getRank();
    auto resRank = resultVectorType.getRank();

    // This can only lower shape_casts where both the source and result types
    // have a single trailing scalable dimension. This is because there is no
    // legal representation of other scalable types in LLVM (and likely won't
    // be soon). There are also (currently) no operations that can index or
    // extract from >= 2-D scalable vectors or scalable vectors of fixed
    // vectors.
    if (!isTrailingDimScalable(sourceVectorType) ||
        !isTrailingDimScalable(resultVectorType)) {
      return failure();
    }

    // The sizes of the trailing dimensions of the source and result vectors,
    // the size of the subvector to move, and the number of elements in the
    // vectors. These are "min" sizes as they are the sizes when vscale == 1.
    auto minSourceTrailingSize = sourceVectorType.getShape().back();
    auto minResultTrailingSize = resultVectorType.getShape().back();
    auto minExtractionSize =
        std::min(minSourceTrailingSize, minResultTrailingSize);
    int64_t minNumElts = 1;
    for (auto size : sourceVectorType.getShape())
      minNumElts *= size;

    // The subvector type to move from the source to the result. Note that this
    // is a scalable vector. This rewrite will generate code in terms of the
    // "min" size (vscale == 1 case), that scales to any vscale.
    auto extractionVectorType = VectorType::get(
        {minExtractionSize}, sourceVectorType.getElementType(), {true});

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));

    SmallVector<int64_t> srcIdx(srcRank);
    SmallVector<int64_t> resIdx(resRank);

    // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
    // once D150000 lands.
    Value currentResultScalableVector;
    Value currentSourceScalableVector;
    for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
      // 1. Extract a scalable subvector from the source vector.
      if (!currentSourceScalableVector) {
        if (srcRank != 1) {
          currentSourceScalableVector = rewriter.create<vector::ExtractOp>(
              loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back());
        } else {
          currentSourceScalableVector = op.getSource();
        }
      }
      Value sourceSubVector = currentSourceScalableVector;
      if (minExtractionSize < minSourceTrailingSize) {
        sourceSubVector = rewriter.create<vector::ScalableExtractOp>(
            loc, extractionVectorType, sourceSubVector, srcIdx.back());
      }

      // 2. Insert the scalable subvector into the result vector.
      if (!currentResultScalableVector) {
        if (minExtractionSize == minResultTrailingSize) {
          currentResultScalableVector = sourceSubVector;
        } else if (resRank != 1) {
          currentResultScalableVector = rewriter.create<vector::ExtractOp>(
              loc, result, llvm::ArrayRef(resIdx).drop_back());
        } else {
          currentResultScalableVector = result;
        }
      }
      if (minExtractionSize < minResultTrailingSize) {
        currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>(
            loc, sourceSubVector, currentResultScalableVector, resIdx.back());
      }

      // 3. Update the source and result scalable vectors if needed.
      if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
          currentResultScalableVector != result) {
        // Finished a row of the result: insert the completed scalable vector
        // into the (n-D) result vector.
        result = rewriter.create<vector::InsertOp>(
            loc, currentResultScalableVector, result,
            llvm::ArrayRef(resIdx).drop_back());
        currentResultScalableVector = {};
      }
      if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
        // Finished a row of the source.
        currentSourceScalableVector = {};
      }

      // 4. Increment the insert/extract indices, stepping by minExtractionSize
      // for the trailing dimensions.
      incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
      incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

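  /// Returns true if the only scalable dimension of `type` is the trailing
  /// one, e.g. true for vector<2x[4]xf32>, false for vector<[2]x4xf32> and
  /// for fully fixed-size vectors.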
  static bool isTrailingDimScalable(VectorType type) {
    return type.getRank() >= 1 && type.getScalableDims().back() &&
           !llvm::is_contained(type.getScalableDims().drop_back(), true);
  }
};

} // namespace

void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ShapeCastOp2DDownCastRewritePattern,
               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
               ScalableShapeCastOpRewritePattern>(patterns.getContext(),
                                                  benefit);
}
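
// Example usage (illustrative sketch, not part of this file): a pass or test
// driver would typically collect these patterns and apply them greedily. The
// surrounding pass boilerplate below is an assumption for illustration only.
//
//   RewritePatternSet patterns(&getContext());
//   vector::populateVectorShapeCastLoweringPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                           std::move(patterns))))
//     signalPassFailure();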