MLIR  14.0.0git
VectorMultiDimReductionTransforms.cpp
Go to the documentation of this file.
1 //===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===//
2 //
3 /// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 /// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5 /// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// This file implements target-independent rewrites of MultiDimReductionOp.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/TypeUtilities.h"
18 
19 #define DEBUG_TYPE "vector-multi-reduction"
20 
21 using namespace mlir;
22 
23 /// This file implements the following transformations as composable atomic
24 /// patterns.
25 
26 /// Converts vector.multi_reduction into inner-most/outer-most reduction form
27 /// by using vector.transpose.
29  : public OpRewritePattern<vector::MultiDimReductionOp> {
30 public:
32 
35  : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
36  useInnerDimsForReduction(
37  options == vector::VectorMultiReductionLowering::InnerReduction) {}
38 
39  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
40  PatternRewriter &rewriter) const override {
41  auto src = multiReductionOp.source();
42  auto loc = multiReductionOp.getLoc();
43  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
44 
45  // Separate reduction and parallel dims
46  auto reductionDimsRange =
47  multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr>();
48  auto reductionDims = llvm::to_vector<4>(llvm::map_range(
49  reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); }));
50  llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
51  reductionDims.end());
52  int64_t reductionSize = reductionDims.size();
53  SmallVector<int64_t, 4> parallelDims;
54  for (int64_t i = 0; i < srcRank; ++i)
55  if (!reductionDimsSet.contains(i))
56  parallelDims.push_back(i);
57 
58  // Add transpose only if inner-most/outer-most dimensions are not parallel
59  // and there are parallel dims.
60  if (parallelDims.empty())
61  return failure();
62  if (useInnerDimsForReduction &&
63  (parallelDims ==
64  llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
65  return failure();
66 
67  if (!useInnerDimsForReduction &&
68  (parallelDims !=
69  llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
70  return failure();
71 
73  if (useInnerDimsForReduction) {
74  indices.append(parallelDims.begin(), parallelDims.end());
75  indices.append(reductionDims.begin(), reductionDims.end());
76  } else {
77  indices.append(reductionDims.begin(), reductionDims.end());
78  indices.append(parallelDims.begin(), parallelDims.end());
79  }
80  auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
81  SmallVector<bool> reductionMask(srcRank, false);
82  for (int i = 0; i < reductionSize; ++i) {
83  if (useInnerDimsForReduction)
84  reductionMask[srcRank - i - 1] = true;
85  else
86  reductionMask[i] = true;
87  }
88  rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
89  multiReductionOp, transposeOp.result(), reductionMask,
90  multiReductionOp.kind());
91  return success();
92  }
93 
94 private:
95  const bool useInnerDimsForReduction;
96 };
97 
98 /// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
99 /// dimensions are either inner most or outer most.
101  : public OpRewritePattern<vector::MultiDimReductionOp> {
102 public:
104 
107  : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
108  useInnerDimsForReduction(
109  options == vector::VectorMultiReductionLowering::InnerReduction) {}
110 
111  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
112  PatternRewriter &rewriter) const override {
113  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
114  auto srcShape = multiReductionOp.getSourceVectorType().getShape();
115  auto loc = multiReductionOp.getLoc();
116 
117  // If rank less than 2, nothing to do.
118  if (srcRank < 2)
119  return failure();
120 
121  // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
122  SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
123  if (srcRank == 2 && reductionMask.front() != reductionMask.back())
124  return failure();
125 
126  // 1. Separate reduction and parallel dims.
127  SmallVector<int64_t, 4> parallelDims, parallelShapes;
128  SmallVector<int64_t, 4> reductionDims, reductionShapes;
129  for (const auto &it : llvm::enumerate(reductionMask)) {
130  int64_t i = it.index();
131  bool isReduction = it.value();
132  if (isReduction) {
133  reductionDims.push_back(i);
134  reductionShapes.push_back(srcShape[i]);
135  } else {
136  parallelDims.push_back(i);
137  parallelShapes.push_back(srcShape[i]);
138  }
139  }
140 
141  // 2. Compute flattened parallel and reduction sizes.
142  int flattenedParallelDim = 0;
143  int flattenedReductionDim = 0;
144  if (!parallelShapes.empty()) {
145  flattenedParallelDim = 1;
146  for (auto d : parallelShapes)
147  flattenedParallelDim *= d;
148  }
149  if (!reductionShapes.empty()) {
150  flattenedReductionDim = 1;
151  for (auto d : reductionShapes)
152  flattenedReductionDim *= d;
153  }
154  // We must at least have some parallel or some reduction.
155  assert((flattenedParallelDim || flattenedReductionDim) &&
156  "expected at least one parallel or reduction dim");
157 
158  // 3. Fail if reduction/parallel dims are not contiguous.
159  // Check parallelDims are exactly [0 .. size).
160  int64_t counter = 0;
161  if (useInnerDimsForReduction &&
162  llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
163  return failure();
164  // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
165  counter = reductionDims.size();
166  if (!useInnerDimsForReduction &&
167  llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
168  return failure();
169 
170  // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
171  // a single parallel (resp. reduction) dim.
174  if (flattenedParallelDim) {
175  mask.push_back(false);
176  vectorShape.push_back(flattenedParallelDim);
177  }
178  if (flattenedReductionDim) {
179  mask.push_back(true);
180  vectorShape.push_back(flattenedReductionDim);
181  }
182  if (!useInnerDimsForReduction && vectorShape.size() == 2) {
183  std::swap(mask.front(), mask.back());
184  std::swap(vectorShape.front(), vectorShape.back());
185  }
186  auto castedType = VectorType::get(
187  vectorShape, multiReductionOp.getSourceVectorType().getElementType());
188  Value cast = rewriter.create<vector::ShapeCastOp>(
189  loc, castedType, multiReductionOp.source());
190 
191  // 5. Creates the flattened form of vector.multi_reduction with inner/outer
192  // most dim as reduction.
193  auto newOp = rewriter.create<vector::MultiDimReductionOp>(
194  loc, cast, mask, multiReductionOp.kind());
195 
196  // 6. If there are no parallel shapes, the result is a scalar.
197  // TODO: support 0-d vectors when available.
198  if (parallelShapes.empty()) {
199  rewriter.replaceOp(multiReductionOp, newOp.dest());
200  return success();
201  }
202 
203  // 7. Creates shape cast for the output n-D -> 2-D
204  VectorType outputCastedType = VectorType::get(
205  parallelShapes,
206  multiReductionOp.getSourceVectorType().getElementType());
207  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
208  multiReductionOp, outputCastedType, newOp.dest());
209  return success();
210  }
211 
212 private:
213  const bool useInnerDimsForReduction;
214 };
215 
216 /// Unrolls vector.multi_reduction with outermost reductions
217 /// and combines the unrolled slices elementwise.
219  : public OpRewritePattern<vector::MultiDimReductionOp> {
221 
222  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
223  PatternRewriter &rewriter) const override {
224  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
225  // Rank-2 ["parallel", "reduce"] or bail.
226  if (srcRank != 2)
227  return failure();
228 
229  if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
230  return failure();
231 
232  auto loc = multiReductionOp.getLoc();
233  ArrayRef<int64_t> srcShape =
234  multiReductionOp.getSourceVectorType().getShape();
235 
236  Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
237  if (!elementType.isIntOrIndexOrFloat())
238  return failure();
239 
240  Value result =
241  rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), 0)
242  .getResult();
243  for (int64_t i = 1; i < srcShape[0]; i++) {
244  auto operand =
245  rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), i);
246  switch (multiReductionOp.kind()) {
247  case vector::CombiningKind::ADD:
248  if (elementType.isIntOrIndex())
249  result = rewriter.create<arith::AddIOp>(loc, operand, result);
250  else
251  result = rewriter.create<arith::AddFOp>(loc, operand, result);
252  break;
253  case vector::CombiningKind::MUL:
254  if (elementType.isIntOrIndex())
255  result = rewriter.create<arith::MulIOp>(loc, operand, result);
256  else
257  result = rewriter.create<arith::MulFOp>(loc, operand, result);
258  break;
259  case vector::CombiningKind::MINUI:
260  result = rewriter.create<arith::MinUIOp>(loc, operand, result);
261  break;
262  case vector::CombiningKind::MINSI:
263  result = rewriter.create<arith::MinSIOp>(loc, operand, result);
264  break;
265  case vector::CombiningKind::MINF:
266  result = rewriter.create<arith::MinFOp>(loc, operand, result);
267  break;
268  case vector::CombiningKind::MAXUI:
269  result = rewriter.create<arith::MaxUIOp>(loc, operand, result);
270  break;
271  case vector::CombiningKind::MAXSI:
272  result = rewriter.create<arith::MaxSIOp>(loc, operand, result);
273  break;
274  case vector::CombiningKind::MAXF:
275  result = rewriter.create<arith::MaxFOp>(loc, operand, result);
276  break;
277  case vector::CombiningKind::AND:
278  result = rewriter.create<arith::AndIOp>(loc, operand, result);
279  break;
280  case vector::CombiningKind::OR:
281  result = rewriter.create<arith::OrIOp>(loc, operand, result);
282  break;
283  case vector::CombiningKind::XOR:
284  result = rewriter.create<arith::XOrIOp>(loc, operand, result);
285  break;
286  }
287  }
288 
289  rewriter.replaceOp(multiReductionOp, result);
290  return success();
291  }
292 };
293 
294 /// Converts 2d vector.multi_reduction with inner most reduction dimension into
295 /// a sequence of vector.reduction ops.
297  : public OpRewritePattern<vector::MultiDimReductionOp> {
299 
300  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
301  PatternRewriter &rewriter) const override {
302  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
303  if (srcRank != 2)
304  return failure();
305 
306  if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
307  return failure();
308 
309  auto loc = multiReductionOp.getLoc();
310  Value result = rewriter.create<ConstantOp>(
311  loc, multiReductionOp.getDestType(),
312  rewriter.getZeroAttr(multiReductionOp.getDestType()));
313  int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
314 
315  // TODO: Add vector::CombiningKind attribute instead of string to
316  // vector.reduction.
317  auto getKindStr = [](vector::CombiningKind kind) {
318  switch (kind) {
319  case vector::CombiningKind::ADD:
320  return "add";
321  case vector::CombiningKind::MUL:
322  return "mul";
323  case vector::CombiningKind::MINUI:
324  return "minui";
325  case vector::CombiningKind::MINSI:
326  return "minsi";
327  case vector::CombiningKind::MINF:
328  return "minf";
329  case vector::CombiningKind::MAXUI:
330  return "maxui";
331  case vector::CombiningKind::MAXSI:
332  return "maxsi";
333  case vector::CombiningKind::MAXF:
334  return "maxf";
335  case vector::CombiningKind::AND:
336  return "and";
337  case vector::CombiningKind::OR:
338  return "or";
339  case vector::CombiningKind::XOR:
340  return "xor";
341  }
342  llvm_unreachable("unknown combining kind");
343  };
344 
345  for (int i = 0; i < outerDim; ++i) {
346  auto v = rewriter.create<vector::ExtractOp>(
347  loc, multiReductionOp.source(), ArrayRef<int64_t>{i});
348  auto reducedValue = rewriter.create<vector::ReductionOp>(
349  loc, getElementTypeOrSelf(multiReductionOp.getDestType()),
350  rewriter.getStringAttr(getKindStr(multiReductionOp.kind())), v,
351  ValueRange{});
352  result = rewriter.create<vector::InsertElementOp>(
353  loc, reducedValue, result,
354  rewriter.create<arith::ConstantIndexOp>(loc, i));
355  }
356  rewriter.replaceOp(multiReductionOp, result);
357  return success();
358  }
359 };
360 
361 /// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
362 /// form with both a single parallel and reduction dimension.
363 /// This is achieved with a simple vector.shape_cast that inserts a leading 1.
364 /// The case with a single parallel dimension is a noop and folds away
365 /// separately.
367  : public OpRewritePattern<vector::MultiDimReductionOp> {
369 
370  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
371  PatternRewriter &rewriter) const override {
372  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
373  // Rank-1 or bail.
374  if (srcRank != 1)
375  return failure();
376 
377  auto loc = multiReductionOp.getLoc();
378  auto srcVectorType = multiReductionOp.getSourceVectorType();
379  auto srcShape = srcVectorType.getShape();
380  auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
381  srcVectorType.getElementType());
382  assert(!multiReductionOp.getDestType().isa<VectorType>() &&
383  "multi_reduction with a single dimension expects a scalar result");
384 
385  // If the unique dim is reduced and we insert a parallel in front, we need a
386  // {false, true} mask.
387  SmallVector<bool, 2> mask{false, true};
388 
389  /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
390  Value cast = rewriter.create<vector::ShapeCastOp>(
391  loc, castedType, multiReductionOp.source());
392  Value reduced = rewriter.create<vector::MultiDimReductionOp>(
393  loc, cast, mask, multiReductionOp.kind());
394  rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
395  ArrayRef<int64_t>{0});
396  return success();
397  }
398 };
399 
403  patterns.getContext(), options);
404  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext());
405  if (options == VectorMultiReductionLowering ::InnerReduction)
406  patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
407  else
408  patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext());
409 }
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
Reduces the rank of vector.multi_reduction nd -> 2d given all reduction dimensions are either inner m...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
Attribute getZeroAttr(Type type)
Definition: Builders.cpp:264
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:89
ReduceMultiDimReductionRank(MLIRContext *context, vector::VectorMultiReductionLowering options)
static ArrayRef< int64_t > vectorShape(Type type)
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
VectorMultiReductionLowering
Enum to control the lowering of vector.multi_reduction operations.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Converts 1d vector.multi_reduction with a single reduction dimension to a 2d form with both a single ...
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
Unrolls vector.multi_reduction with outermost reductions and combines results.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
This file implements the following transformations as composable atomic patterns. ...
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override
static llvm::ManagedStatic< PassManagerOptions > options
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
Specialization of arith.constant op that returns an integer of index type.
Definition: Arithmetic.h:78
InnerOuterDimReductionConversion(MLIRContext *context, vector::VectorMultiReductionLowering options)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:85
Converts 2d vector.multi_reduction with inner most reduction dimension into a sequence of vector...
This class provides an abstraction over the different types of ranges over Values.
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:201
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override
MLIRContext * getContext() const
Definition: PatternMatch.h:906