MLIR  19.0.0git
LowerVectorMultiReduction.cpp
Go to the documentation of this file.
1 //===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===//
2 //
3 /// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 /// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5 /// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.multi_reduction' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/TypeUtilities.h"
21 
22 namespace mlir {
23 namespace vector {
24 #define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
25 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
26 } // namespace vector
27 } // namespace mlir
28 
29 #define DEBUG_TYPE "vector-multi-reduction"
30 
31 using namespace mlir;
32 
33 namespace {
34 /// This file implements the following transformations as composable atomic
35 /// patterns.
36 
37 /// Converts vector.multi_reduction into inner-most/outer-most reduction form
38 /// by using vector.transpose
39 class InnerOuterDimReductionConversion
40  : public OpRewritePattern<vector::MultiDimReductionOp> {
41 public:
43 
44  explicit InnerOuterDimReductionConversion(
45  MLIRContext *context, vector::VectorMultiReductionLowering options,
46  PatternBenefit benefit = 1)
48  useInnerDimsForReduction(
49  options == vector::VectorMultiReductionLowering::InnerReduction) {}
50 
51  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
52  PatternRewriter &rewriter) const override {
53  // Vector mask setup.
54  OpBuilder::InsertionGuard guard(rewriter);
55  auto maskableOp =
56  cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
57  Operation *rootOp;
58  if (maskableOp.isMasked()) {
59  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
60  rootOp = maskableOp.getMaskingOp();
61  } else {
62  rootOp = multiReductionOp;
63  }
64 
65  auto src = multiReductionOp.getSource();
66  auto loc = multiReductionOp.getLoc();
67  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
68 
69  // Separate reduction and parallel dims
70  auto reductionDimsRange =
71  multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
72  auto reductionDims = llvm::to_vector<4>(llvm::map_range(
73  reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); }));
74  llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
75  reductionDims.end());
76  int64_t reductionSize = reductionDims.size();
77  SmallVector<int64_t, 4> parallelDims;
78  for (int64_t i = 0; i < srcRank; ++i)
79  if (!reductionDimsSet.contains(i))
80  parallelDims.push_back(i);
81 
82  // Add transpose only if inner-most/outer-most dimensions are not parallel
83  // and there are parallel dims.
84  if (parallelDims.empty())
85  return failure();
86  if (useInnerDimsForReduction &&
87  (parallelDims ==
88  llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
89  return failure();
90 
91  if (!useInnerDimsForReduction &&
92  (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
93  reductionDims.size(),
94  parallelDims.size() + reductionDims.size()))))
95  return failure();
96 
98  if (useInnerDimsForReduction) {
99  indices.append(parallelDims.begin(), parallelDims.end());
100  indices.append(reductionDims.begin(), reductionDims.end());
101  } else {
102  indices.append(reductionDims.begin(), reductionDims.end());
103  indices.append(parallelDims.begin(), parallelDims.end());
104  }
105 
106  // If masked, transpose the original mask.
107  Value transposedMask;
108  if (maskableOp.isMasked()) {
109  transposedMask = rewriter.create<vector::TransposeOp>(
110  loc, maskableOp.getMaskingOp().getMask(), indices);
111  }
112 
113  // Transpose reduction source.
114  auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
115  SmallVector<bool> reductionMask(srcRank, false);
116  for (int i = 0; i < reductionSize; ++i) {
117  if (useInnerDimsForReduction)
118  reductionMask[srcRank - i - 1] = true;
119  else
120  reductionMask[i] = true;
121  }
122 
123  Operation *newMultiRedOp = rewriter.create<vector::MultiDimReductionOp>(
124  multiReductionOp.getLoc(), transposeOp.getResult(),
125  multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
126  newMultiRedOp =
127  mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask);
128 
129  rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0));
130  return success();
131  }
132 
133 private:
134  const bool useInnerDimsForReduction;
135 };
136 
137 /// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
138 /// dimensions are either inner most or outer most.
139 class ReduceMultiDimReductionRank
140  : public OpRewritePattern<vector::MultiDimReductionOp> {
141 public:
143 
144  explicit ReduceMultiDimReductionRank(
145  MLIRContext *context, vector::VectorMultiReductionLowering options,
146  PatternBenefit benefit = 1)
148  useInnerDimsForReduction(
149  options == vector::VectorMultiReductionLowering::InnerReduction) {}
150 
151  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
152  PatternRewriter &rewriter) const override {
153  // Vector mask setup.
154  OpBuilder::InsertionGuard guard(rewriter);
155  auto maskableOp =
156  cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
157  Operation *rootOp;
158  if (maskableOp.isMasked()) {
159  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
160  rootOp = maskableOp.getMaskingOp();
161  } else {
162  rootOp = multiReductionOp;
163  }
164 
165  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
166  auto srcShape = multiReductionOp.getSourceVectorType().getShape();
167  auto srcScalableDims =
168  multiReductionOp.getSourceVectorType().getScalableDims();
169  auto loc = multiReductionOp.getLoc();
170 
171  // If rank less than 2, nothing to do.
172  if (srcRank < 2)
173  return failure();
174 
175  // Allow only 1 scalable dimensions. Otherwise we could end-up with e.g.
176  // `vscale * vscale` that's currently not modelled.
177  if (llvm::count(srcScalableDims, true) > 1)
178  return failure();
179 
180  // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
181  SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
182  if (srcRank == 2 && reductionMask.front() != reductionMask.back())
183  return failure();
184 
185  // 1. Separate reduction and parallel dims.
186  SmallVector<int64_t, 4> parallelDims, parallelShapes;
187  SmallVector<bool, 4> parallelScalableDims;
188  SmallVector<int64_t, 4> reductionDims, reductionShapes;
189  bool isReductionDimScalable = false;
190  for (const auto &it : llvm::enumerate(reductionMask)) {
191  int64_t i = it.index();
192  bool isReduction = it.value();
193  if (isReduction) {
194  reductionDims.push_back(i);
195  reductionShapes.push_back(srcShape[i]);
196  isReductionDimScalable |= srcScalableDims[i];
197  } else {
198  parallelDims.push_back(i);
199  parallelShapes.push_back(srcShape[i]);
200  parallelScalableDims.push_back(srcScalableDims[i]);
201  }
202  }
203 
204  // 2. Compute flattened parallel and reduction sizes.
205  int flattenedParallelDim = 0;
206  int flattenedReductionDim = 0;
207  if (!parallelShapes.empty()) {
208  flattenedParallelDim = 1;
209  for (auto d : parallelShapes)
210  flattenedParallelDim *= d;
211  }
212  if (!reductionShapes.empty()) {
213  flattenedReductionDim = 1;
214  for (auto d : reductionShapes)
215  flattenedReductionDim *= d;
216  }
217  // We must at least have some parallel or some reduction.
218  assert((flattenedParallelDim || flattenedReductionDim) &&
219  "expected at least one parallel or reduction dim");
220 
221  // 3. Fail if reduction/parallel dims are not contiguous.
222  // Check parallelDims are exactly [0 .. size).
223  int64_t counter = 0;
224  if (useInnerDimsForReduction &&
225  llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
226  return failure();
227  // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
228  counter = reductionDims.size();
229  if (!useInnerDimsForReduction &&
230  llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
231  return failure();
232 
233  // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
234  // a single parallel (resp. reduction) dim.
236  SmallVector<bool, 2> scalableDims;
238  bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
239  if (flattenedParallelDim) {
240  mask.push_back(false);
241  vectorShape.push_back(flattenedParallelDim);
242  scalableDims.push_back(isParallelDimScalable);
243  }
244  if (flattenedReductionDim) {
245  mask.push_back(true);
246  vectorShape.push_back(flattenedReductionDim);
247  scalableDims.push_back(isReductionDimScalable);
248  }
249  if (!useInnerDimsForReduction && vectorShape.size() == 2) {
250  std::swap(mask.front(), mask.back());
251  std::swap(vectorShape.front(), vectorShape.back());
252  std::swap(scalableDims.front(), scalableDims.back());
253  }
254 
255  Value newVectorMask;
256  if (maskableOp.isMasked()) {
257  Value vectorMask = maskableOp.getMaskingOp().getMask();
258  auto maskCastedType = VectorType::get(
259  vectorShape,
260  llvm::cast<VectorType>(vectorMask.getType()).getElementType());
261  newVectorMask =
262  rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
263  }
264 
265  auto castedType = VectorType::get(
266  vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
267  scalableDims);
268  Value cast = rewriter.create<vector::ShapeCastOp>(
269  loc, castedType, multiReductionOp.getSource());
270 
271  Value acc = multiReductionOp.getAcc();
272  if (flattenedParallelDim) {
273  auto accType = VectorType::get(
274  {flattenedParallelDim},
275  multiReductionOp.getSourceVectorType().getElementType(),
276  /*scalableDims=*/{isParallelDimScalable});
277  acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
278  }
279  // 6. Creates the flattened form of vector.multi_reduction with inner/outer
280  // most dim as reduction.
281  Operation *newMultiDimRedOp = rewriter.create<vector::MultiDimReductionOp>(
282  loc, cast, acc, mask, multiReductionOp.getKind());
283  newMultiDimRedOp =
284  mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask);
285 
286  // 7. If there are no parallel shapes, the result is a scalar.
287  // TODO: support 0-d vectors when available.
288  if (parallelShapes.empty()) {
289  rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0));
290  return success();
291  }
292 
293  // 8. Creates shape cast for the output n-D -> 2-D.
294  VectorType outputCastedType = VectorType::get(
295  parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
296  parallelScalableDims);
297  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
298  rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
299  return success();
300  }
301 
302 private:
303  const bool useInnerDimsForReduction;
304 };
305 
306 /// Unrolls vector.multi_reduction with outermost reductions
307 /// and combines results
308 struct TwoDimMultiReductionToElementWise
309  : public OpRewritePattern<vector::MultiDimReductionOp> {
311 
312  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
313  PatternRewriter &rewriter) const override {
314  auto maskableOp =
315  cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
316  if (maskableOp.isMasked())
317  // TODO: Support masking.
318  return failure();
319 
320  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
321  // Rank-2 ["parallel", "reduce"] or bail.
322  if (srcRank != 2)
323  return failure();
324 
325  if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
326  return failure();
327 
328  auto loc = multiReductionOp.getLoc();
329  ArrayRef<int64_t> srcShape =
330  multiReductionOp.getSourceVectorType().getShape();
331 
332  Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
333  if (!elementType.isIntOrIndexOrFloat())
334  return failure();
335 
336  Value result = multiReductionOp.getAcc();
337  for (int64_t i = 0; i < srcShape[0]; i++) {
338  auto operand = rewriter.create<vector::ExtractOp>(
339  loc, multiReductionOp.getSource(), i);
340  result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
341  operand, result);
342  }
343 
344  rewriter.replaceOp(multiReductionOp, result);
345  return success();
346  }
347 };
348 
349 /// Converts 2d vector.multi_reduction with inner most reduction dimension into
350 /// a sequence of vector.reduction ops.
351 struct TwoDimMultiReductionToReduction
352  : public OpRewritePattern<vector::MultiDimReductionOp> {
354 
355  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
356  PatternRewriter &rewriter) const override {
357  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
358  if (srcRank != 2)
359  return failure();
360 
361  if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
362  return failure();
363 
364  // Vector mask setup.
365  OpBuilder::InsertionGuard guard(rewriter);
366  auto maskableOp =
367  cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
368  Operation *rootOp;
369  if (maskableOp.isMasked()) {
370  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
371  rootOp = maskableOp.getMaskingOp();
372  } else {
373  rootOp = multiReductionOp;
374  }
375 
376  auto loc = multiReductionOp.getLoc();
377  Value result = rewriter.create<arith::ConstantOp>(
378  loc, multiReductionOp.getDestType(),
379  rewriter.getZeroAttr(multiReductionOp.getDestType()));
380  int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
381 
382  for (int i = 0; i < outerDim; ++i) {
383  auto v = rewriter.create<vector::ExtractOp>(
384  loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
385  auto acc = rewriter.create<vector::ExtractOp>(
386  loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
387  Operation *reductionOp = rewriter.create<vector::ReductionOp>(
388  loc, multiReductionOp.getKind(), v, acc);
389 
390  // If masked, slice the mask and mask the new reduction operation.
391  if (maskableOp.isMasked()) {
392  Value mask = rewriter.create<vector::ExtractOp>(
393  loc, maskableOp.getMaskingOp().getMask(), ArrayRef<int64_t>{i});
394  reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
395  }
396 
397  result = rewriter.create<vector::InsertElementOp>(
398  loc, reductionOp->getResult(0), result,
399  rewriter.create<arith::ConstantIndexOp>(loc, i));
400  }
401 
402  rewriter.replaceOp(rootOp, result);
403  return success();
404  }
405 };
406 
407 /// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
408 /// form with both a single parallel and reduction dimension.
409 /// This is achieved with a simple vector.shape_cast that inserts a leading 1.
410 /// The case with a single parallel dimension is a noop and folds away
411 /// separately.
412 struct OneDimMultiReductionToTwoDim
413  : public OpRewritePattern<vector::MultiDimReductionOp> {
415 
416  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
417  PatternRewriter &rewriter) const override {
418  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
419  // Rank-1 or bail.
420  if (srcRank != 1)
421  return failure();
422 
423  // Vector mask setup.
424  OpBuilder::InsertionGuard guard(rewriter);
425  auto maskableOp =
426  cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
427  Operation *rootOp;
428  Value mask;
429  if (maskableOp.isMasked()) {
430  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
431  rootOp = maskableOp.getMaskingOp();
432  mask = maskableOp.getMaskingOp().getMask();
433  } else {
434  rootOp = multiReductionOp;
435  }
436 
437  auto loc = multiReductionOp.getLoc();
438  auto srcVectorType = multiReductionOp.getSourceVectorType();
439  auto srcShape = srcVectorType.getShape();
440  auto castedType = VectorType::get(
441  ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),
442  ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});
443 
444  auto accType =
445  VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
446  assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
447  "multi_reduction with a single dimension expects a scalar result");
448 
449  // If the unique dim is reduced and we insert a parallel in front, we need a
450  // {false, true} mask.
451  SmallVector<bool, 2> reductionMask{false, true};
452 
453  /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
454  Value cast = rewriter.create<vector::ShapeCastOp>(
455  loc, castedType, multiReductionOp.getSource());
456  Value castAcc = rewriter.create<vector::BroadcastOp>(
457  loc, accType, multiReductionOp.getAcc());
458  Value castMask;
459  if (maskableOp.isMasked()) {
460  auto maskType = llvm::cast<VectorType>(mask.getType());
461  auto castMaskType = VectorType::get(
462  ArrayRef<int64_t>{1, maskType.getShape().back()},
463  maskType.getElementType(),
464  ArrayRef<bool>{false, maskType.getScalableDims().back()});
465  castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
466  }
467 
468  Operation *newOp = rewriter.create<vector::MultiDimReductionOp>(
469  loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
470  newOp = vector::maskOperation(rewriter, newOp, castMask);
471 
472  rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
473  ArrayRef<int64_t>{0});
474  return success();
475  }
476 };
477 
478 struct LowerVectorMultiReductionPass
479  : public vector::impl::LowerVectorMultiReductionBase<
480  LowerVectorMultiReductionPass> {
481  LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
482  this->loweringStrategy = option;
483  }
484 
485  void runOnOperation() override {
486  Operation *op = getOperation();
487  MLIRContext *context = op->getContext();
488 
489  RewritePatternSet loweringPatterns(context);
491  this->loweringStrategy);
492 
493  if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
494  signalPassFailure();
495  }
496 
497  void getDependentDialects(DialectRegistry &registry) const override {
498  registry.insert<vector::VectorDialect>();
499  }
500 };
501 
502 } // namespace
503 
505  RewritePatternSet &patterns, VectorMultiReductionLowering options,
506  PatternBenefit benefit) {
507  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
508  patterns.getContext(), options, benefit);
509  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
510  if (options == VectorMultiReductionLowering ::InnerReduction)
511  patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
512  benefit);
513  else
514  patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
515  benefit);
516 }
517 
518 std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
519  vector::VectorMultiReductionLowering option) {
520  return std::make_unique<LowerVectorMultiReductionPass>(option);
521 }
static llvm::ManagedStatic< PassManagerOptions > options
static VectorShape vectorShape(Type type)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:123
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
std::unique_ptr< Pass > createLowerVectorMultiReductionPass(VectorMultiReductionLowering option=VectorMultiReductionLowering::InnerParallel)
Creates an instance of the vector.multi_reduction lowering pass.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector....
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362