//===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.multi_reduction' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-multi-reduction"

using namespace mlir;

namespace {
/// This file implements the following transformations as composable atomic
/// patterns.
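///
/// As a rough usage sketch (not part of this file; the driver and the anchor
/// op are assumptions of this example), the patterns are collected via
/// populateVectorMultiReductionLoweringPatterns and applied greedily:
///
///   RewritePatternSet patterns(ctx);
///   vector::populateVectorMultiReductionLoweringPatterns(
///       patterns, vector::VectorMultiReductionLowering::InnerReduction);
///   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));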

/// Converts vector.multi_reduction into inner-most/outer-most reduction form
/// by using vector.transpose
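///
/// For example (an illustrative sketch; the shapes and SSA names are made up
/// for this comment, not taken from the tests), with the InnerReduction
/// option a reduction over the outer dimension
///
///   %0 = vector.multi_reduction <add>, %src, %acc [0]
///       : vector<8x4xf32> to vector<4xf32>
///
/// is rewritten so the reduced dimension becomes the innermost one:
///
///   %t = vector.transpose %src, [1, 0] : vector<8x4xf32> to vector<4x8xf32>
///   %0 = vector.multi_reduction <add>, %t, %acc [1]
///       : vector<4x8xf32> to vector<4xf32>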
class InnerOuterDimReductionConversion
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  explicit InnerOuterDimReductionConversion(
      MLIRContext *context, vector::VectorMultiReductionLowering options,
      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto src = multiReductionOp.getSource();
    auto loc = multiReductionOp.getLoc();
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();

    // Separate reduction and parallel dims
    auto reductionDimsRange =
        multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
    auto reductionDims = llvm::to_vector<4>(llvm::map_range(
        reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); }));
    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
                                                  reductionDims.end());
    int64_t reductionSize = reductionDims.size();
    SmallVector<int64_t, 4> parallelDims;
    for (int64_t i = 0; i < srcRank; ++i)
      if (!reductionDimsSet.contains(i))
        parallelDims.push_back(i);

    // Add transpose only if inner-most/outer-most dimensions are not parallel
    // and there are parallel dims.
    if (parallelDims.empty())
      return failure();
    if (useInnerDimsForReduction &&
        (parallelDims ==
         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
      return failure();

    if (!useInnerDimsForReduction &&
        (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
                             reductionDims.size(),
                             parallelDims.size() + reductionDims.size()))))
      return failure();

    SmallVector<int64_t> indices;
    if (useInnerDimsForReduction) {
      indices.append(parallelDims.begin(), parallelDims.end());
      indices.append(reductionDims.begin(), reductionDims.end());
    } else {
      indices.append(reductionDims.begin(), reductionDims.end());
      indices.append(parallelDims.begin(), parallelDims.end());
    }

    // If masked, transpose the original mask.
    Value transposedMask;
    if (maskableOp.isMasked()) {
      transposedMask = rewriter.create<vector::TransposeOp>(
          loc, maskableOp.getMaskingOp().getMask(), indices);
    }

    // Transpose reduction source.
    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
    SmallVector<bool> reductionMask(srcRank, false);
    for (int i = 0; i < reductionSize; ++i) {
      if (useInnerDimsForReduction)
        reductionMask[srcRank - i - 1] = true;
      else
        reductionMask[i] = true;
    }

    Operation *newMultiRedOp = rewriter.create<vector::MultiDimReductionOp>(
        multiReductionOp.getLoc(), transposeOp.getResult(),
        multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
    newMultiRedOp =
        mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask);

    rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0));
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
/// dimensions are either inner most or outer most.
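///
/// For example (an illustrative sketch; the shapes and SSA names are made up
/// for this comment, not taken from the tests), with inner reductions
///
///   %0 = vector.multi_reduction <add>, %src, %acc [2, 3]
///       : vector<2x3x4x5xf32> to vector<2x3xf32>
///
/// is flattened into a rank-2 reduction surrounded by shape casts:
///
///   %in  = vector.shape_cast %src : vector<2x3x4x5xf32> to vector<6x20xf32>
///   %a   = vector.shape_cast %acc : vector<2x3xf32> to vector<6xf32>
///   %red = vector.multi_reduction <add>, %in, %a [1]
///       : vector<6x20xf32> to vector<6xf32>
///   %0   = vector.shape_cast %red : vector<6xf32> to vector<2x3xf32>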
class ReduceMultiDimReductionRank
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  explicit ReduceMultiDimReductionRank(
      MLIRContext *context, vector::VectorMultiReductionLowering options,
      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    auto srcShape = multiReductionOp.getSourceVectorType().getShape();
    auto srcScalableDims =
        multiReductionOp.getSourceVectorType().getScalableDims();
    auto loc = multiReductionOp.getLoc();

    // If rank less than 2, nothing to do.
    if (srcRank < 2)
      return failure();

    // Allow only 1 scalable dimension. Otherwise we could end up with, e.g.,
    // `vscale * vscale`, which is currently not modelled.
    if (llvm::count(srcScalableDims, true) > 1)
      return failure();

    // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
    SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
    if (srcRank == 2 && reductionMask.front() != reductionMask.back())
      return failure();

    // 1. Separate reduction and parallel dims.
    SmallVector<int64_t, 4> parallelDims, parallelShapes;
    SmallVector<bool, 4> parallelScalableDims;
    SmallVector<int64_t, 4> reductionDims, reductionShapes;
    bool isReductionDimScalable = false;
    for (const auto &it : llvm::enumerate(reductionMask)) {
      int64_t i = it.index();
      bool isReduction = it.value();
      if (isReduction) {
        reductionDims.push_back(i);
        reductionShapes.push_back(srcShape[i]);
        isReductionDimScalable |= srcScalableDims[i];
      } else {
        parallelDims.push_back(i);
        parallelShapes.push_back(srcShape[i]);
        parallelScalableDims.push_back(srcScalableDims[i]);
      }
    }

    // 2. Compute flattened parallel and reduction sizes.
    int flattenedParallelDim = 0;
    int flattenedReductionDim = 0;
    if (!parallelShapes.empty()) {
      flattenedParallelDim = 1;
      for (auto d : parallelShapes)
        flattenedParallelDim *= d;
    }
    if (!reductionShapes.empty()) {
      flattenedReductionDim = 1;
      for (auto d : reductionShapes)
        flattenedReductionDim *= d;
    }
    // We must at least have some parallel or some reduction.
    assert((flattenedParallelDim || flattenedReductionDim) &&
           "expected at least one parallel or reduction dim");

    // 3. Fail if reduction/parallel dims are not contiguous.
    // Check parallelDims are exactly [0 .. size).
    int64_t counter = 0;
    if (useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();
    // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
    counter = reductionDims.size();
    if (!useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();

    // 4. Shape cast to collapse consecutive parallel (resp. reduction) dims
    // into a single parallel (resp. reduction) dim.
    SmallVector<bool, 2> mask;
    SmallVector<bool, 2> scalableDims;
    SmallVector<int64_t, 2> vectorShape;
    bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
    if (flattenedParallelDim) {
      mask.push_back(false);
      vectorShape.push_back(flattenedParallelDim);
      scalableDims.push_back(isParallelDimScalable);
    }
    if (flattenedReductionDim) {
      mask.push_back(true);
      vectorShape.push_back(flattenedReductionDim);
      scalableDims.push_back(isReductionDimScalable);
    }
    if (!useInnerDimsForReduction && vectorShape.size() == 2) {
      std::swap(mask.front(), mask.back());
      std::swap(vectorShape.front(), vectorShape.back());
      std::swap(scalableDims.front(), scalableDims.back());
    }

    Value newVectorMask;
    if (maskableOp.isMasked()) {
      Value vectorMask = maskableOp.getMaskingOp().getMask();
      auto maskCastedType = VectorType::get(
          vectorShape,
          llvm::cast<VectorType>(vectorMask.getType()).getElementType());
      newVectorMask =
          rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
    }

    auto castedType = VectorType::get(
        vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
        scalableDims);
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());

    Value acc = multiReductionOp.getAcc();
    if (flattenedParallelDim) {
      auto accType = VectorType::get(
          {flattenedParallelDim},
          multiReductionOp.getSourceVectorType().getElementType(),
          /*scalableDims=*/{isParallelDimScalable});
      acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
    }
    // 6. Creates the flattened form of vector.multi_reduction with inner/outer
    // most dim as reduction.
    Operation *newMultiDimRedOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, acc, mask, multiReductionOp.getKind());
    newMultiDimRedOp =
        mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask);

    // 7. If there are no parallel shapes, the result is a scalar.
    // TODO: support 0-d vectors when available.
    if (parallelShapes.empty()) {
      rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0));
      return success();
    }

    // 8. Creates a shape cast for the output: the flattened reduction result
    // is cast back to the original parallel shape.
    VectorType outputCastedType = VectorType::get(
        parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
        parallelScalableDims);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Unrolls a rank-2 vector.multi_reduction whose outermost dimension is
/// reduced, combining the extracted rows element-wise.
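///
/// For example (an illustrative sketch; the shapes and SSA names are made up
/// for this comment, not taken from the tests):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [0]
///       : vector<3x4xf32> to vector<4xf32>
///
/// is unrolled into element-wise combines of the extracted rows:
///
///   %r0 = vector.extract %src[0] : vector<4xf32> from vector<3x4xf32>
///   %s0 = arith.addf %r0, %acc : vector<4xf32>
///   %r1 = vector.extract %src[1] : vector<4xf32> from vector<3x4xf32>
///   %s1 = arith.addf %r1, %s0 : vector<4xf32>
///   %r2 = vector.extract %src[2] : vector<4xf32> from vector<3x4xf32>
///   %0  = arith.addf %r2, %s1 : vector<4xf32>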
struct TwoDimMultiReductionToElementWise
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    if (maskableOp.isMasked())
      // TODO: Support masking.
      return failure();

    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-2 ["reduce", "parallel"] or bail.
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
      return failure();

    auto loc = multiReductionOp.getLoc();
    ArrayRef<int64_t> srcShape =
        multiReductionOp.getSourceVectorType().getShape();

    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
    if (!elementType.isIntOrIndexOrFloat())
      return failure();

    Value result = multiReductionOp.getAcc();
    for (int64_t i = 0; i < srcShape[0]; i++) {
      auto operand = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), i);
      result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
                                  operand, result);
    }

    rewriter.replaceOp(multiReductionOp, result);
    return success();
  }
};

/// Converts 2d vector.multi_reduction with inner most reduction dimension into
/// a sequence of vector.reduction ops.
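///
/// For example (an illustrative sketch; the shapes and SSA names are made up
/// for this comment, not taken from the tests):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [1]
///       : vector<2x8xf32> to vector<2xf32>
///
/// becomes one vector.reduction per outer index, with the scalar results
/// inserted back into a rank-1 vector:
///
///   %cst = arith.constant dense<0.000000e+00> : vector<2xf32>
///   %v0  = vector.extract %src[0] : vector<8xf32> from vector<2x8xf32>
///   %a0  = vector.extract %acc[0] : f32 from vector<2xf32>
///   %r0  = vector.reduction <add>, %v0, %a0 : vector<8xf32> into f32
///   %c0  = arith.constant 0 : index
///   %t0  = vector.insertelement %r0, %cst[%c0 : index] : vector<2xf32>
///   // ... and likewise for row 1, yielding the final vector<2xf32>.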
struct TwoDimMultiReductionToReduction
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
      return failure();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto loc = multiReductionOp.getLoc();
    Value result = rewriter.create<arith::ConstantOp>(
        loc, multiReductionOp.getDestType(),
        rewriter.getZeroAttr(multiReductionOp.getDestType()));
    int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];

    for (int i = 0; i < outerDim; ++i) {
      auto v = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
      auto acc = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
      Operation *reductionOp = rewriter.create<vector::ReductionOp>(
          loc, multiReductionOp.getKind(), v, acc);

      // If masked, slice the mask and mask the new reduction operation.
      if (maskableOp.isMasked()) {
        Value mask = rewriter.create<vector::ExtractOp>(
            loc, maskableOp.getMaskingOp().getMask(), ArrayRef<int64_t>{i});
        reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
      }

      result = rewriter.create<vector::InsertElementOp>(
          loc, reductionOp->getResult(0), result,
          rewriter.create<arith::ConstantIndexOp>(loc, i));
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

/// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
/// form with both a single parallel and reduction dimension.
/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
/// The case with a single parallel dimension is a noop and folds away
/// separately.
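///
/// For example (an illustrative sketch; the shapes and SSA names are made up
/// for this comment, not taken from the tests):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [0]
///       : vector<8xf32> to f32
///
/// becomes
///
///   %c = vector.shape_cast %src : vector<8xf32> to vector<1x8xf32>
///   %a = vector.broadcast %acc : f32 to vector<1xf32>
///   %r = vector.multi_reduction <add>, %c, %a [1]
///       : vector<1x8xf32> to vector<1xf32>
///   %0 = vector.extract %r[0] : f32 from vector<1xf32>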
struct OneDimMultiReductionToTwoDim
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-1 or bail.
    if (srcRank != 1)
      return failure();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = multiReductionOp;
    }

    auto loc = multiReductionOp.getLoc();
    auto srcVectorType = multiReductionOp.getSourceVectorType();
    auto srcShape = srcVectorType.getShape();
    auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
                                      srcVectorType.getElementType());
    auto accType =
        VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
    assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
           "multi_reduction with a single dimension expects a scalar result");

    // If the unique dim is reduced and we insert a parallel in front, we need a
    // {false, true} mask.
    SmallVector<bool, 2> reductionMask{false, true};

    /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());
    Value castAcc = rewriter.create<vector::BroadcastOp>(
        loc, accType, multiReductionOp.getAcc());
    Value castMask;
    if (maskableOp.isMasked()) {
      auto maskType = llvm::cast<ShapedType>(mask.getType());
      auto castMaskType =
          VectorType::get(ArrayRef<int64_t>{1, maskType.getShape().back()},
                          maskType.getElementType());
      castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
    }

    Operation *newOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
    newOp = vector::maskOperation(rewriter, newOp, castMask);

    rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
                                                   ArrayRef<int64_t>{0});
    return success();
  }
};
} // namespace

void mlir::vector::populateVectorMultiReductionLoweringPatterns(
    RewritePatternSet &patterns, VectorMultiReductionLowering options,
    PatternBenefit benefit) {
  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
      patterns.getContext(), options, benefit);
  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
  if (options == VectorMultiReductionLowering::InnerReduction)
    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
                                                  benefit);
  else
    patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
                                                    benefit);
}