MLIR  21.0.0git
LowerVectorMultiReduction.cpp
Go to the documentation of this file.
1 //===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===//
2 //
3 /// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 /// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5 /// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.multi_reduction' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/TypeUtilities.h"
21 
22 namespace mlir {
23 namespace vector {
24 #define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
25 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
26 } // namespace vector
27 } // namespace mlir
28 
29 #define DEBUG_TYPE "vector-multi-reduction"
30 
31 using namespace mlir;
32 
33 namespace {
34 /// This file implements the following transformations as composable atomic
35 /// patterns.
36 
37 /// Converts vector.multi_reduction into inner-most/outer-most reduction form
38 /// by using vector.transpose
39 class InnerOuterDimReductionConversion
40  : public OpRewritePattern<vector::MultiDimReductionOp> {
41 public:
43 
44  explicit InnerOuterDimReductionConversion(
45  MLIRContext *context, vector::VectorMultiReductionLowering options,
46  PatternBenefit benefit = 1)
48  useInnerDimsForReduction(
49  options == vector::VectorMultiReductionLowering::InnerReduction) {}
50 
51  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
52  PatternRewriter &rewriter) const override {
53  // Vector mask setup.
54  OpBuilder::InsertionGuard guard(rewriter);
55  auto maskableOp =
56  cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
57  Operation *rootOp;
58  if (maskableOp.isMasked()) {
59  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
60  rootOp = maskableOp.getMaskingOp();
61  } else {
62  rootOp = multiReductionOp;
63  }
64 
65  auto src = multiReductionOp.getSource();
66  auto loc = multiReductionOp.getLoc();
67  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
68 
69  // Separate reduction and parallel dims
70  ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
71  llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
72  reductionDims.end());
73  int64_t reductionSize = reductionDims.size();
74  SmallVector<int64_t, 4> parallelDims;
75  for (int64_t i = 0; i < srcRank; ++i)
76  if (!reductionDimsSet.contains(i))
77  parallelDims.push_back(i);
78 
79  // Add transpose only if inner-most/outer-most dimensions are not parallel
80  // and there are parallel dims.
81  if (parallelDims.empty())
82  return failure();
83  if (useInnerDimsForReduction &&
84  (parallelDims ==
85  llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
86  return failure();
87 
88  if (!useInnerDimsForReduction &&
89  (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
90  reductionDims.size(),
91  parallelDims.size() + reductionDims.size()))))
92  return failure();
93 
95  if (useInnerDimsForReduction) {
96  indices.append(parallelDims.begin(), parallelDims.end());
97  indices.append(reductionDims.begin(), reductionDims.end());
98  } else {
99  indices.append(reductionDims.begin(), reductionDims.end());
100  indices.append(parallelDims.begin(), parallelDims.end());
101  }
102 
103  // If masked, transpose the original mask.
104  Value transposedMask;
105  if (maskableOp.isMasked()) {
106  transposedMask = rewriter.create<vector::TransposeOp>(
107  loc, maskableOp.getMaskingOp().getMask(), indices);
108  }
109 
110  // Transpose reduction source.
111  auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
112  SmallVector<bool> reductionMask(srcRank, false);
113  for (int i = 0; i < reductionSize; ++i) {
114  if (useInnerDimsForReduction)
115  reductionMask[srcRank - i - 1] = true;
116  else
117  reductionMask[i] = true;
118  }
119 
120  Operation *newMultiRedOp = rewriter.create<vector::MultiDimReductionOp>(
121  multiReductionOp.getLoc(), transposeOp.getResult(),
122  multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
123  newMultiRedOp =
124  mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask);
125 
126  rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0));
127  return success();
128  }
129 
130 private:
131  const bool useInnerDimsForReduction;
132 };
133 
134 /// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
135 /// dimensions are either inner most or outer most.
136 class ReduceMultiDimReductionRank
137  : public OpRewritePattern<vector::MultiDimReductionOp> {
138 public:
140 
141  explicit ReduceMultiDimReductionRank(
142  MLIRContext *context, vector::VectorMultiReductionLowering options,
143  PatternBenefit benefit = 1)
145  useInnerDimsForReduction(
146  options == vector::VectorMultiReductionLowering::InnerReduction) {}
147 
148  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
149  PatternRewriter &rewriter) const override {
150  // Vector mask setup.
151  OpBuilder::InsertionGuard guard(rewriter);
152  auto maskableOp =
153  cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
154  Operation *rootOp;
155  if (maskableOp.isMasked()) {
156  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
157  rootOp = maskableOp.getMaskingOp();
158  } else {
159  rootOp = multiReductionOp;
160  }
161 
162  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
163  auto srcShape = multiReductionOp.getSourceVectorType().getShape();
164  auto srcScalableDims =
165  multiReductionOp.getSourceVectorType().getScalableDims();
166  auto loc = multiReductionOp.getLoc();
167 
168  // If rank less than 2, nothing to do.
169  if (srcRank < 2)
170  return failure();
171 
172  // Allow only 1 scalable dimensions. Otherwise we could end-up with e.g.
173  // `vscale * vscale` that's currently not modelled.
174  if (llvm::count(srcScalableDims, true) > 1)
175  return failure();
176 
177  // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
178  SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
179  if (srcRank == 2 && reductionMask.front() != reductionMask.back())
180  return failure();
181 
182  // 1. Separate reduction and parallel dims.
183  SmallVector<int64_t, 4> parallelDims, parallelShapes;
184  SmallVector<bool, 4> parallelScalableDims;
185  SmallVector<int64_t, 4> reductionDims, reductionShapes;
186  bool isReductionDimScalable = false;
187  for (const auto &it : llvm::enumerate(reductionMask)) {
188  int64_t i = it.index();
189  bool isReduction = it.value();
190  if (isReduction) {
191  reductionDims.push_back(i);
192  reductionShapes.push_back(srcShape[i]);
193  isReductionDimScalable |= srcScalableDims[i];
194  } else {
195  parallelDims.push_back(i);
196  parallelShapes.push_back(srcShape[i]);
197  parallelScalableDims.push_back(srcScalableDims[i]);
198  }
199  }
200 
201  // 2. Compute flattened parallel and reduction sizes.
202  int flattenedParallelDim = 0;
203  int flattenedReductionDim = 0;
204  if (!parallelShapes.empty()) {
205  flattenedParallelDim = 1;
206  for (auto d : parallelShapes)
207  flattenedParallelDim *= d;
208  }
209  if (!reductionShapes.empty()) {
210  flattenedReductionDim = 1;
211  for (auto d : reductionShapes)
212  flattenedReductionDim *= d;
213  }
214  // We must at least have some parallel or some reduction.
215  assert((flattenedParallelDim || flattenedReductionDim) &&
216  "expected at least one parallel or reduction dim");
217 
218  // 3. Fail if reduction/parallel dims are not contiguous.
219  // Check parallelDims are exactly [0 .. size).
220  int64_t counter = 0;
221  if (useInnerDimsForReduction &&
222  llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
223  return failure();
224  // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
225  counter = reductionDims.size();
226  if (!useInnerDimsForReduction &&
227  llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
228  return failure();
229 
230  // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
231  // a single parallel (resp. reduction) dim.
233  SmallVector<bool, 2> scalableDims;
235  bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
236  if (flattenedParallelDim) {
237  mask.push_back(false);
238  vectorShape.push_back(flattenedParallelDim);
239  scalableDims.push_back(isParallelDimScalable);
240  }
241  if (flattenedReductionDim) {
242  mask.push_back(true);
243  vectorShape.push_back(flattenedReductionDim);
244  scalableDims.push_back(isReductionDimScalable);
245  }
246  if (!useInnerDimsForReduction && vectorShape.size() == 2) {
247  std::swap(mask.front(), mask.back());
248  std::swap(vectorShape.front(), vectorShape.back());
249  std::swap(scalableDims.front(), scalableDims.back());
250  }
251 
252  Value newVectorMask;
253  if (maskableOp.isMasked()) {
254  Value vectorMask = maskableOp.getMaskingOp().getMask();
255  auto maskCastedType = VectorType::get(
256  vectorShape,
257  llvm::cast<VectorType>(vectorMask.getType()).getElementType());
258  newVectorMask =
259  rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
260  }
261 
262  auto castedType = VectorType::get(
263  vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
264  scalableDims);
265  Value cast = rewriter.create<vector::ShapeCastOp>(
266  loc, castedType, multiReductionOp.getSource());
267 
268  Value acc = multiReductionOp.getAcc();
269  if (flattenedParallelDim) {
270  auto accType = VectorType::get(
271  {flattenedParallelDim},
272  multiReductionOp.getSourceVectorType().getElementType(),
273  /*scalableDims=*/{isParallelDimScalable});
274  acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
275  }
276  // 6. Creates the flattened form of vector.multi_reduction with inner/outer
277  // most dim as reduction.
278  Operation *newMultiDimRedOp = rewriter.create<vector::MultiDimReductionOp>(
279  loc, cast, acc, mask, multiReductionOp.getKind());
280  newMultiDimRedOp =
281  mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask);
282 
283  // 7. If there are no parallel shapes, the result is a scalar.
284  // TODO: support 0-d vectors when available.
285  if (parallelShapes.empty()) {
286  rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0));
287  return success();
288  }
289 
290  // 8. Creates shape cast for the output n-D -> 2-D.
291  VectorType outputCastedType = VectorType::get(
292  parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
293  parallelScalableDims);
294  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
295  rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
296  return success();
297  }
298 
299 private:
300  const bool useInnerDimsForReduction;
301 };
302 
303 /// Unrolls vector.multi_reduction with outermost reductions
304 /// and combines results
305 struct TwoDimMultiReductionToElementWise
306  : public OpRewritePattern<vector::MultiDimReductionOp> {
308 
309  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
310  PatternRewriter &rewriter) const override {
311  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
312  // Rank-2 ["parallel", "reduce"] or bail.
313  if (srcRank != 2)
314  return failure();
315 
316  if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
317  return failure();
318 
319  auto loc = multiReductionOp.getLoc();
320  ArrayRef<int64_t> srcShape =
321  multiReductionOp.getSourceVectorType().getShape();
322 
323  Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
324  if (!elementType.isIntOrIndexOrFloat())
325  return failure();
326 
327  OpBuilder::InsertionGuard guard(rewriter);
328  auto maskableOp =
329  cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
330  Operation *rootOp;
331  Value mask = nullptr;
332  if (maskableOp.isMasked()) {
333  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
334  rootOp = maskableOp.getMaskingOp();
335  mask = maskableOp.getMaskingOp().getMask();
336  } else {
337  rootOp = multiReductionOp;
338  }
339 
340  Value result = multiReductionOp.getAcc();
341  for (int64_t i = 0; i < srcShape[0]; i++) {
342  auto operand = rewriter.create<vector::ExtractOp>(
343  loc, multiReductionOp.getSource(), i);
344  Value extractMask = nullptr;
345  if (mask) {
346  extractMask = rewriter.create<vector::ExtractOp>(loc, mask, i);
347  }
348  result =
349  makeArithReduction(rewriter, loc, multiReductionOp.getKind(), operand,
350  result, /*fastmath=*/nullptr, extractMask);
351  }
352 
353  rewriter.replaceOp(rootOp, result);
354  return success();
355  }
356 };
357 
358 /// Converts 2d vector.multi_reduction with inner most reduction dimension into
359 /// a sequence of vector.reduction ops.
360 struct TwoDimMultiReductionToReduction
361  : public OpRewritePattern<vector::MultiDimReductionOp> {
363 
364  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
365  PatternRewriter &rewriter) const override {
366  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
367  if (srcRank != 2)
368  return failure();
369 
370  if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
371  return failure();
372 
373  // Vector mask setup.
374  OpBuilder::InsertionGuard guard(rewriter);
375  auto maskableOp =
376  cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
377  Operation *rootOp;
378  if (maskableOp.isMasked()) {
379  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
380  rootOp = maskableOp.getMaskingOp();
381  } else {
382  rootOp = multiReductionOp;
383  }
384 
385  auto loc = multiReductionOp.getLoc();
386  Value result = rewriter.create<arith::ConstantOp>(
387  loc, multiReductionOp.getDestType(),
388  rewriter.getZeroAttr(multiReductionOp.getDestType()));
389  int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
390 
391  for (int i = 0; i < outerDim; ++i) {
392  auto v = rewriter.create<vector::ExtractOp>(
393  loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
394  auto acc = rewriter.create<vector::ExtractOp>(
395  loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
396  Operation *reductionOp = rewriter.create<vector::ReductionOp>(
397  loc, multiReductionOp.getKind(), v, acc);
398 
399  // If masked, slice the mask and mask the new reduction operation.
400  if (maskableOp.isMasked()) {
401  Value mask = rewriter.create<vector::ExtractOp>(
402  loc, maskableOp.getMaskingOp().getMask(), ArrayRef<int64_t>{i});
403  reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
404  }
405 
406  result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
407  result, i);
408  }
409 
410  rewriter.replaceOp(rootOp, result);
411  return success();
412  }
413 };
414 
415 /// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
416 /// form with both a single parallel and reduction dimension.
417 /// This is achieved with a simple vector.shape_cast that inserts a leading 1.
418 /// The case with a single parallel dimension is a noop and folds away
419 /// separately.
420 struct OneDimMultiReductionToTwoDim
421  : public OpRewritePattern<vector::MultiDimReductionOp> {
423 
424  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
425  PatternRewriter &rewriter) const override {
426  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
427  // Rank-1 or bail.
428  if (srcRank != 1)
429  return failure();
430 
431  // Vector mask setup.
432  OpBuilder::InsertionGuard guard(rewriter);
433  auto maskableOp =
434  cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
435  Operation *rootOp;
436  Value mask;
437  if (maskableOp.isMasked()) {
438  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
439  rootOp = maskableOp.getMaskingOp();
440  mask = maskableOp.getMaskingOp().getMask();
441  } else {
442  rootOp = multiReductionOp;
443  }
444 
445  auto loc = multiReductionOp.getLoc();
446  auto srcVectorType = multiReductionOp.getSourceVectorType();
447  auto srcShape = srcVectorType.getShape();
448  auto castedType = VectorType::get(
449  ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),
450  ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});
451 
452  auto accType =
453  VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
454  assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
455  "multi_reduction with a single dimension expects a scalar result");
456 
457  // If the unique dim is reduced and we insert a parallel in front, we need a
458  // {false, true} mask.
459  SmallVector<bool, 2> reductionMask{false, true};
460 
461  /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
462  Value cast = rewriter.create<vector::ShapeCastOp>(
463  loc, castedType, multiReductionOp.getSource());
464  Value castAcc = rewriter.create<vector::BroadcastOp>(
465  loc, accType, multiReductionOp.getAcc());
466  Value castMask;
467  if (maskableOp.isMasked()) {
468  auto maskType = llvm::cast<VectorType>(mask.getType());
469  auto castMaskType = VectorType::get(
470  ArrayRef<int64_t>{1, maskType.getShape().back()},
471  maskType.getElementType(),
472  ArrayRef<bool>{false, maskType.getScalableDims().back()});
473  castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
474  }
475 
476  Operation *newOp = rewriter.create<vector::MultiDimReductionOp>(
477  loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
478  newOp = vector::maskOperation(rewriter, newOp, castMask);
479 
480  rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
481  ArrayRef<int64_t>{0});
482  return success();
483  }
484 };
485 
486 struct LowerVectorMultiReductionPass
487  : public vector::impl::LowerVectorMultiReductionBase<
488  LowerVectorMultiReductionPass> {
489  LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
490  this->loweringStrategy = option;
491  }
492 
493  void runOnOperation() override {
494  Operation *op = getOperation();
495  MLIRContext *context = op->getContext();
496 
497  RewritePatternSet loweringPatterns(context);
499  this->loweringStrategy);
500 
501  if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
502  signalPassFailure();
503  }
504 
505  void getDependentDialects(DialectRegistry &registry) const override {
506  registry.insert<vector::VectorDialect>();
507  }
508 };
509 
510 } // namespace
511 
513  RewritePatternSet &patterns, VectorMultiReductionLowering options,
514  PatternBenefit benefit) {
515  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
516  patterns.getContext(), options, benefit);
517  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
518  if (options == VectorMultiReductionLowering ::InnerReduction)
519  patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
520  benefit);
521  else
522  patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
523  benefit);
524 }
525 
526 std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
527  vector::VectorMultiReductionLowering option) {
528  return std::make_unique<LowerVectorMultiReductionPass>(option);
529 }
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< VectorShape > vectorShape(Type type)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:112
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
std::unique_ptr< Pass > createLowerVectorMultiReductionPass(VectorMultiReductionLowering option=VectorMultiReductionLowering::InnerParallel)
Creates an instance of the vector.multi_reduction lowering pass.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector....
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362