MLIR  20.0.0git
LowerVectorShapeCast.cpp
Go to the documentation of this file.
1 //===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.shape_cast' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#include <numeric>
35 
36 #define DEBUG_TYPE "vector-shape-cast-lowering"
37 
38 using namespace mlir;
39 using namespace mlir::vector;
40 
41 namespace {
42 /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
43 /// vectors progressively on the way to target llvm.matrix intrinsics.
44 /// This iterates over the most major dimension of the 2-D vector and performs
45 /// rewrites into:
46 /// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
47 class ShapeCastOp2DDownCastRewritePattern
48  : public OpRewritePattern<vector::ShapeCastOp> {
49 public:
51 
52  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
53  PatternRewriter &rewriter) const override {
54  auto sourceVectorType = op.getSourceVectorType();
55  auto resultVectorType = op.getResultVectorType();
56 
57  if (sourceVectorType.isScalable() || resultVectorType.isScalable())
58  return failure();
59 
60  if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
61  return failure();
62 
63  auto loc = op.getLoc();
64  Value desc = rewriter.create<arith::ConstantOp>(
65  loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
66  unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
67  for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
68  Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
69  desc = rewriter.create<vector::InsertStridedSliceOp>(
70  loc, vec, desc,
71  /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
72  }
73  rewriter.replaceOp(op, desc);
74  return success();
75  }
76 };
77 
78 /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
79 /// vectors progressively.
80 /// This iterates over the most major dimension of the 2-D vector and performs
81 /// rewrites into:
82 /// vector.extract_strided_slice from 1-D + vector.insert into 2-D
83 /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
84 class ShapeCastOp2DUpCastRewritePattern
85  : public OpRewritePattern<vector::ShapeCastOp> {
86 public:
88 
89  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
90  PatternRewriter &rewriter) const override {
91  auto sourceVectorType = op.getSourceVectorType();
92  auto resultVectorType = op.getResultVectorType();
93 
94  if (sourceVectorType.isScalable() || resultVectorType.isScalable())
95  return failure();
96 
97  if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
98  return failure();
99 
100  auto loc = op.getLoc();
101  Value desc = rewriter.create<arith::ConstantOp>(
102  loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
103  unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
104  for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
105  Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
106  loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
107  /*sizes=*/mostMinorVectorSize,
108  /*strides=*/1);
109  desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
110  }
111  rewriter.replaceOp(op, desc);
112  return success();
113  }
114 };
115 
116 static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp,
117  int dimIdx, int initialStep = 1) {
118  int step = initialStep;
119  for (int d = dimIdx; d >= 0; d--) {
120  idx[d] += step;
121  if (idx[d] >= tp.getDimSize(d)) {
122  idx[d] = 0;
123  step = 1;
124  } else {
125  break;
126  }
127  }
128 }
129 
130 // We typically should not lower general shape cast operations into data
131 // movement instructions, since the assumption is that these casts are
132 // optimized away during progressive lowering. For completeness, however,
133 // we fall back to a reference implementation that moves all elements
134 // into the right place if we get here.
135 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
136 public:
138 
139  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
140  PatternRewriter &rewriter) const override {
141  Location loc = op.getLoc();
142  auto sourceVectorType = op.getSourceVectorType();
143  auto resultVectorType = op.getResultVectorType();
144 
145  if (sourceVectorType.isScalable() || resultVectorType.isScalable())
146  return failure();
147 
148  // Special case 2D / 1D lowerings with better implementations.
149  // TODO: make is ND / 1D to allow generic ND -> 1D -> MD.
150  int64_t srcRank = sourceVectorType.getRank();
151  int64_t resRank = resultVectorType.getRank();
152  if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
153  return failure();
154 
155  // Generic ShapeCast lowering path goes all the way down to unrolled scalar
156  // extract/insert chains.
157  // TODO: consider evolving the semantics to only allow 1D source or dest and
158  // drop this potentially very expensive lowering.
159  // Compute number of elements involved in the reshape.
160  int64_t numElts = 1;
161  for (int64_t r = 0; r < srcRank; r++)
162  numElts *= sourceVectorType.getDimSize(r);
163  // Replace with data movement operations:
164  // x[0,0,0] = y[0,0]
165  // x[0,0,1] = y[0,1]
166  // x[0,1,0] = y[0,2]
167  // etc., incrementing the two index vectors "row-major"
168  // within the source and result shape.
169  SmallVector<int64_t> srcIdx(srcRank);
170  SmallVector<int64_t> resIdx(resRank);
171  Value result = rewriter.create<arith::ConstantOp>(
172  loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
173  for (int64_t i = 0; i < numElts; i++) {
174  if (i != 0) {
175  incIdx(srcIdx, sourceVectorType, srcRank - 1);
176  incIdx(resIdx, resultVectorType, resRank - 1);
177  }
178 
179  Value extract;
180  if (srcRank == 0) {
181  // 0-D vector special case
182  assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
183  extract = rewriter.create<vector::ExtractElementOp>(
184  loc, op.getSourceVectorType().getElementType(), op.getSource());
185  } else {
186  extract =
187  rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
188  }
189 
190  if (resRank == 0) {
191  // 0-D vector special case
192  assert(resIdx.empty() && "Unexpected indices for 0-D vector");
193  result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
194  } else {
195  result =
196  rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
197  }
198  }
199  rewriter.replaceOp(op, result);
200  return success();
201  }
202 };
203 
204 /// A shape_cast lowering for scalable vectors with a single trailing scalable
205 /// dimension. This is similar to the general shape_cast lowering but makes use
206 /// of vector.scalable.insert and vector.scalable.extract to move elements a
207 /// subvector at a time.
208 ///
209 /// E.g.:
210 /// ```
211 /// // Flatten scalable vector
212 /// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
213 /// ```
214 /// is rewritten to:
215 /// ```
216 /// // Flatten scalable vector
217 /// %c = arith.constant dense<0> : vector<[8]xi32>
218 /// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
219 /// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
220 /// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
221 /// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
222 /// ```
223 /// or:
224 /// ```
225 /// // Un-flatten scalable vector
226 /// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
227 /// ```
228 /// is rewritten to:
229 /// ```
230 /// // Un-flatten scalable vector
231 /// %c = arith.constant dense<0> : vector<2x1x[4]xi32>
232 /// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
233 /// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
234 /// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
235 /// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
236 /// ```
237 class ScalableShapeCastOpRewritePattern
238  : public OpRewritePattern<vector::ShapeCastOp> {
239 public:
241 
242  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
243  PatternRewriter &rewriter) const override {
244 
245  Location loc = op.getLoc();
246  auto sourceVectorType = op.getSourceVectorType();
247  auto resultVectorType = op.getResultVectorType();
248  auto srcRank = sourceVectorType.getRank();
249  auto resRank = resultVectorType.getRank();
250 
251  // This can only lower shape_casts where both the source and result types
252  // have a single trailing scalable dimension. This is because there are no
253  // legal representation of other scalable types in LLVM (and likely won't be
254  // soon). There are also (currently) no operations that can index or extract
255  // from >= 2D scalable vectors or scalable vectors of fixed vectors.
256  if (!isTrailingDimScalable(sourceVectorType) ||
257  !isTrailingDimScalable(resultVectorType)) {
258  return failure();
259  }
260 
261  // The sizes of the trailing dimension of the source and result vectors, the
262  // size of subvector to move, and the number of elements in the vectors.
263  // These are "min" sizes as they are the size when vscale == 1.
264  auto minSourceTrailingSize = sourceVectorType.getShape().back();
265  auto minResultTrailingSize = resultVectorType.getShape().back();
266  auto minExtractionSize =
267  std::min(minSourceTrailingSize, minResultTrailingSize);
268  int64_t minNumElts = 1;
269  for (auto size : sourceVectorType.getShape())
270  minNumElts *= size;
271 
272  // The subvector type to move from the source to the result. Note that this
273  // is a scalable vector. This rewrite will generate code in terms of the
274  // "min" size (vscale == 1 case), that scales to any vscale.
275  auto extractionVectorType = VectorType::get(
276  {minExtractionSize}, sourceVectorType.getElementType(), {true});
277 
278  Value result = rewriter.create<arith::ConstantOp>(
279  loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
280 
281  SmallVector<int64_t> srcIdx(srcRank);
282  SmallVector<int64_t> resIdx(resRank);
283 
284  // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
285  // once D150000 lands.
286  Value currentResultScalableVector;
287  Value currentSourceScalableVector;
288  for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
289  // 1. Extract a scalable subvector from the source vector.
290  if (!currentSourceScalableVector) {
291  if (srcRank != 1) {
292  currentSourceScalableVector = rewriter.create<vector::ExtractOp>(
293  loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back());
294  } else {
295  currentSourceScalableVector = op.getSource();
296  }
297  }
298  Value sourceSubVector = currentSourceScalableVector;
299  if (minExtractionSize < minSourceTrailingSize) {
300  sourceSubVector = rewriter.create<vector::ScalableExtractOp>(
301  loc, extractionVectorType, sourceSubVector, srcIdx.back());
302  }
303 
304  // 2. Insert the scalable subvector into the result vector.
305  if (!currentResultScalableVector) {
306  if (minExtractionSize == minResultTrailingSize) {
307  currentResultScalableVector = sourceSubVector;
308  } else if (resRank != 1) {
309  currentResultScalableVector = rewriter.create<vector::ExtractOp>(
310  loc, result, llvm::ArrayRef(resIdx).drop_back());
311  } else {
312  currentResultScalableVector = result;
313  }
314  }
315  if (minExtractionSize < minResultTrailingSize) {
316  currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>(
317  loc, sourceSubVector, currentResultScalableVector, resIdx.back());
318  }
319 
320  // 3. Update the source and result scalable vectors if needed.
321  if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
322  currentResultScalableVector != result) {
323  // Finished row of result. Insert complete scalable vector into result
324  // (n-D) vector.
325  result = rewriter.create<vector::InsertOp>(
326  loc, currentResultScalableVector, result,
327  llvm::ArrayRef(resIdx).drop_back());
328  currentResultScalableVector = {};
329  }
330  if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
331  // Finished row of source.
332  currentSourceScalableVector = {};
333  }
334 
335  // 4. Increment the insert/extract indices, stepping by minExtractionSize
336  // for the trailing dimensions.
337  incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
338  incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
339  }
340 
341  rewriter.replaceOp(op, result);
342  return success();
343  }
344 
345  static bool isTrailingDimScalable(VectorType type) {
346  return type.getRank() >= 1 && type.getScalableDims().back() &&
347  !llvm::is_contained(type.getScalableDims().drop_back(), true);
348  }
349 };
350 
351 } // namespace
352 
354  RewritePatternSet &patterns, PatternBenefit benefit) {
355  patterns.add<ShapeCastOp2DDownCastRewritePattern,
356  ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
357  ScalableShapeCastOpRewritePattern>(patterns.getContext(),
358  benefit);
359 }
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362