MLIR  16.0.0git
VectorMultiDimReductionTransforms.cpp
Go to the documentation of this file.
1 //===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===//
2 //
3 /// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 /// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5 /// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// This file implements target-independent rewrites of MultiDimReductionOp.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/TypeUtilities.h"
19 
20 #define DEBUG_TYPE "vector-multi-reduction"
21 
22 using namespace mlir;
23 
24 // This file implements the following transformations as composable atomic
25 // patterns.
26 
27 /// Converts vector.multi_reduction into inner-most/outer-most reduction form
28 /// by using vector.transpose
30  : public OpRewritePattern<vector::MultiDimReductionOp> {
31 public:
33 
36  : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
37  useInnerDimsForReduction(
38  options == vector::VectorMultiReductionLowering::InnerReduction) {}
39 
40  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
41  PatternRewriter &rewriter) const override {
42  auto src = multiReductionOp.getSource();
43  auto loc = multiReductionOp.getLoc();
44  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
45 
46  // Separate reduction and parallel dims
47  auto reductionDimsRange =
48  multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
49  auto reductionDims = llvm::to_vector<4>(llvm::map_range(
50  reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); }));
51  llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
52  reductionDims.end());
53  int64_t reductionSize = reductionDims.size();
54  SmallVector<int64_t, 4> parallelDims;
55  for (int64_t i = 0; i < srcRank; ++i)
56  if (!reductionDimsSet.contains(i))
57  parallelDims.push_back(i);
58 
59  // Add transpose only if inner-most/outer-most dimensions are not parallel
60  // and there are parallel dims.
61  if (parallelDims.empty())
62  return failure();
63  if (useInnerDimsForReduction &&
64  (parallelDims ==
65  llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
66  return failure();
67 
68  if (!useInnerDimsForReduction &&
69  (parallelDims !=
70  llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
71  return failure();
72 
74  if (useInnerDimsForReduction) {
75  indices.append(parallelDims.begin(), parallelDims.end());
76  indices.append(reductionDims.begin(), reductionDims.end());
77  } else {
78  indices.append(reductionDims.begin(), reductionDims.end());
79  indices.append(parallelDims.begin(), parallelDims.end());
80  }
81  auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
82  SmallVector<bool> reductionMask(srcRank, false);
83  for (int i = 0; i < reductionSize; ++i) {
84  if (useInnerDimsForReduction)
85  reductionMask[srcRank - i - 1] = true;
86  else
87  reductionMask[i] = true;
88  }
89  rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
90  multiReductionOp, transposeOp.getResult(), multiReductionOp.getAcc(),
91  reductionMask, multiReductionOp.getKind());
92  return success();
93  }
94 
95 private:
96  const bool useInnerDimsForReduction;
97 };
98 
99 /// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
100 /// dimensions are either inner most or outer most.
102  : public OpRewritePattern<vector::MultiDimReductionOp> {
103 public:
105 
108  : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
109  useInnerDimsForReduction(
110  options == vector::VectorMultiReductionLowering::InnerReduction) {}
111 
112  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
113  PatternRewriter &rewriter) const override {
114  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
115  auto srcShape = multiReductionOp.getSourceVectorType().getShape();
116  auto loc = multiReductionOp.getLoc();
117 
118  // If rank less than 2, nothing to do.
119  if (srcRank < 2)
120  return failure();
121 
122  // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
123  SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
124  if (srcRank == 2 && reductionMask.front() != reductionMask.back())
125  return failure();
126 
127  // 1. Separate reduction and parallel dims.
128  SmallVector<int64_t, 4> parallelDims, parallelShapes;
129  SmallVector<int64_t, 4> reductionDims, reductionShapes;
130  for (const auto &it : llvm::enumerate(reductionMask)) {
131  int64_t i = it.index();
132  bool isReduction = it.value();
133  if (isReduction) {
134  reductionDims.push_back(i);
135  reductionShapes.push_back(srcShape[i]);
136  } else {
137  parallelDims.push_back(i);
138  parallelShapes.push_back(srcShape[i]);
139  }
140  }
141 
142  // 2. Compute flattened parallel and reduction sizes.
143  int flattenedParallelDim = 0;
144  int flattenedReductionDim = 0;
145  if (!parallelShapes.empty()) {
146  flattenedParallelDim = 1;
147  for (auto d : parallelShapes)
148  flattenedParallelDim *= d;
149  }
150  if (!reductionShapes.empty()) {
151  flattenedReductionDim = 1;
152  for (auto d : reductionShapes)
153  flattenedReductionDim *= d;
154  }
155  // We must at least have some parallel or some reduction.
156  assert((flattenedParallelDim || flattenedReductionDim) &&
157  "expected at least one parallel or reduction dim");
158 
159  // 3. Fail if reduction/parallel dims are not contiguous.
160  // Check parallelDims are exactly [0 .. size).
161  int64_t counter = 0;
162  if (useInnerDimsForReduction &&
163  llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
164  return failure();
165  // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
166  counter = reductionDims.size();
167  if (!useInnerDimsForReduction &&
168  llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
169  return failure();
170 
171  // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
172  // a single parallel (resp. reduction) dim.
175  if (flattenedParallelDim) {
176  mask.push_back(false);
177  vectorShape.push_back(flattenedParallelDim);
178  }
179  if (flattenedReductionDim) {
180  mask.push_back(true);
181  vectorShape.push_back(flattenedReductionDim);
182  }
183  if (!useInnerDimsForReduction && vectorShape.size() == 2) {
184  std::swap(mask.front(), mask.back());
185  std::swap(vectorShape.front(), vectorShape.back());
186  }
187  auto castedType = VectorType::get(
188  vectorShape, multiReductionOp.getSourceVectorType().getElementType());
189  Value cast = rewriter.create<vector::ShapeCastOp>(
190  loc, castedType, multiReductionOp.getSource());
191  Value acc = multiReductionOp.getAcc();
192  if (flattenedParallelDim) {
193  auto accType = VectorType::get(
194  {flattenedParallelDim},
195  multiReductionOp.getSourceVectorType().getElementType());
196  acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
197  }
198  // 5. Creates the flattened form of vector.multi_reduction with inner/outer
199  // most dim as reduction.
200  auto newOp = rewriter.create<vector::MultiDimReductionOp>(
201  loc, cast, acc, mask, multiReductionOp.getKind());
202 
203  // 6. If there are no parallel shapes, the result is a scalar.
204  // TODO: support 0-d vectors when available.
205  if (parallelShapes.empty()) {
206  rewriter.replaceOp(multiReductionOp, newOp.getDest());
207  return success();
208  }
209 
210  // 7. Creates shape cast for the output n-D -> 2-D
211  VectorType outputCastedType = VectorType::get(
212  parallelShapes,
213  multiReductionOp.getSourceVectorType().getElementType());
214  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
215  multiReductionOp, outputCastedType, newOp.getDest());
216  return success();
217  }
218 
219 private:
220  const bool useInnerDimsForReduction;
221 };
222 
223 /// Unrolls vector.multi_reduction with outermost reductions
224 /// and combines results
226  : public OpRewritePattern<vector::MultiDimReductionOp> {
228 
229  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
230  PatternRewriter &rewriter) const override {
231  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
232  // Rank-2 ["parallel", "reduce"] or bail.
233  if (srcRank != 2)
234  return failure();
235 
236  if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
237  return failure();
238 
239  auto loc = multiReductionOp.getLoc();
240  ArrayRef<int64_t> srcShape =
241  multiReductionOp.getSourceVectorType().getShape();
242 
243  Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
244  if (!elementType.isIntOrIndexOrFloat())
245  return failure();
246 
247  Value result = multiReductionOp.getAcc();
248  for (int64_t i = 0; i < srcShape[0]; i++) {
249  auto operand = rewriter.create<vector::ExtractOp>(
250  loc, multiReductionOp.getSource(), i);
251  result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
252  operand, result);
253  }
254 
255  rewriter.replaceOp(multiReductionOp, result);
256  return success();
257  }
258 };
259 
260 /// Converts 2d vector.multi_reduction with inner most reduction dimension into
261 /// a sequence of vector.reduction ops.
263  : public OpRewritePattern<vector::MultiDimReductionOp> {
265 
266  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
267  PatternRewriter &rewriter) const override {
268  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
269  if (srcRank != 2)
270  return failure();
271 
272  if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
273  return failure();
274 
275  auto loc = multiReductionOp.getLoc();
276  Value result = rewriter.create<arith::ConstantOp>(
277  loc, multiReductionOp.getDestType(),
278  rewriter.getZeroAttr(multiReductionOp.getDestType()));
279  int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
280 
281  for (int i = 0; i < outerDim; ++i) {
282  auto v = rewriter.create<vector::ExtractOp>(
283  loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
284  auto acc = rewriter.create<vector::ExtractOp>(
285  loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
286  auto reducedValue = rewriter.create<vector::ReductionOp>(
287  loc, multiReductionOp.getKind(), v, acc);
288  result = rewriter.create<vector::InsertElementOp>(
289  loc, reducedValue, result,
290  rewriter.create<arith::ConstantIndexOp>(loc, i));
291  }
292  rewriter.replaceOp(multiReductionOp, result);
293  return success();
294  }
295 };
296 
297 /// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
298 /// form with both a single parallel and reduction dimension.
299 /// This is achieved with a simple vector.shape_cast that inserts a leading 1.
300 /// The case with a single parallel dimension is a noop and folds away
301 /// separately.
303  : public OpRewritePattern<vector::MultiDimReductionOp> {
305 
306  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
307  PatternRewriter &rewriter) const override {
308  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
309  // Rank-1 or bail.
310  if (srcRank != 1)
311  return failure();
312 
313  auto loc = multiReductionOp.getLoc();
314  auto srcVectorType = multiReductionOp.getSourceVectorType();
315  auto srcShape = srcVectorType.getShape();
316  auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
317  srcVectorType.getElementType());
318  auto accType =
319  VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
320  assert(!multiReductionOp.getDestType().isa<VectorType>() &&
321  "multi_reduction with a single dimension expects a scalar result");
322 
323  // If the unique dim is reduced and we insert a parallel in front, we need a
324  // {false, true} mask.
325  SmallVector<bool, 2> mask{false, true};
326 
327  /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
328  Value cast = rewriter.create<vector::ShapeCastOp>(
329  loc, castedType, multiReductionOp.getSource());
330  Value castAcc = rewriter.create<vector::BroadcastOp>(
331  loc, accType, multiReductionOp.getAcc());
332  Value reduced = rewriter.create<vector::MultiDimReductionOp>(
333  loc, cast, castAcc, mask, multiReductionOp.getKind());
334  rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
335  ArrayRef<int64_t>{0});
336  return success();
337  }
338 };
339 
343  patterns.getContext(), options);
// Rank-1 reductions are lifted to rank-2 so the 2-D patterns below can apply.
344  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext());
// The final 2-D lowering follows the chosen reduction placement:
// inner-most reductions become one vector.reduction per row; otherwise the
// rows are extracted and combined elementwise into the accumulator.
345  if (options == VectorMultiReductionLowering ::InnerReduction)
346  patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
347  else
348  patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext());
349 }
Include the generated interface declarations.
Reduces the rank of vector.multi_reduction nd -> 2d given all reduction dimensions are either inner m...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
Attribute getZeroAttr(Type type)
Definition: Builders.cpp:288
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:89
ReduceMultiDimReductionRank(MLIRContext *context, vector::VectorMultiReductionLowering options)
static ArrayRef< int64_t > vectorShape(Type type)
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
VectorMultiReductionLowering
Enum to control the lowering of vector.multi_reduction operations.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Converts 1d vector.multi_reduction with a single reduction dimension to a 2d form with both a single ...
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
Unrolls vector.multi_reduction with outermost reductions and combines results.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:233
This file implements the following transformations as composable atomic patterns. ...
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value v2)
Return the result value of reducing two scalar/vector values with the corresponding arith operation...
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override
static llvm::ManagedStatic< PassManagerOptions > options
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
Specialization of arith.constant op that returns an integer of index type.
Definition: Arithmetic.h:80
InnerOuterDimReductionConversion(MLIRContext *context, vector::VectorMultiReductionLowering options)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
Converts 2d vector.multi_reduction with inner most reduction dimension into a sequence of vector...
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override
MLIRContext * getContext() const