MLIR  22.0.0git
LowerVectorMultiReduction.cpp
Go to the documentation of this file.
1 //===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===//
2 //
3 /// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 /// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5 /// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.multi_reduction' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/TypeUtilities.h"
21 
22 namespace mlir {
23 namespace vector {
24 #define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
25 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
26 } // namespace vector
27 } // namespace mlir
28 
29 #define DEBUG_TYPE "vector-multi-reduction"
30 
31 using namespace mlir;
32 
33 namespace {
34 /// This file implements the following transformations as composable atomic
35 /// patterns.
36 
37 /// Converts vector.multi_reduction into inner-most/outer-most reduction form
38 /// by using vector.transpose
39 class InnerOuterDimReductionConversion
40  : public OpRewritePattern<vector::MultiDimReductionOp> {
41 public:
43 
44  explicit InnerOuterDimReductionConversion(
45  MLIRContext *context, vector::VectorMultiReductionLowering options,
46  PatternBenefit benefit = 1)
48  useInnerDimsForReduction(
49  options == vector::VectorMultiReductionLowering::InnerReduction) {}
50 
51  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
52  PatternRewriter &rewriter) const override {
53  // Vector mask setup.
54  OpBuilder::InsertionGuard guard(rewriter);
55  auto maskableOp =
56  cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
57  Operation *rootOp;
58  if (maskableOp.isMasked()) {
59  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
60  rootOp = maskableOp.getMaskingOp();
61  } else {
62  rootOp = multiReductionOp;
63  }
64 
65  auto src = multiReductionOp.getSource();
66  auto loc = multiReductionOp.getLoc();
67  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
68 
69  // Separate reduction and parallel dims
70  ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
71  llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
72  reductionDims.end());
73  int64_t reductionSize = reductionDims.size();
74  SmallVector<int64_t, 4> parallelDims;
75  for (int64_t i = 0; i < srcRank; ++i)
76  if (!reductionDimsSet.contains(i))
77  parallelDims.push_back(i);
78 
79  // Add transpose only if inner-most/outer-most dimensions are not parallel
80  // and there are parallel dims.
81  if (parallelDims.empty())
82  return failure();
83  if (useInnerDimsForReduction &&
84  (parallelDims ==
85  llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
86  return failure();
87 
88  if (!useInnerDimsForReduction &&
89  (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
90  reductionDims.size(),
91  parallelDims.size() + reductionDims.size()))))
92  return failure();
93 
95  if (useInnerDimsForReduction) {
96  indices.append(parallelDims.begin(), parallelDims.end());
97  indices.append(reductionDims.begin(), reductionDims.end());
98  } else {
99  indices.append(reductionDims.begin(), reductionDims.end());
100  indices.append(parallelDims.begin(), parallelDims.end());
101  }
102 
103  // If masked, transpose the original mask.
104  Value transposedMask;
105  if (maskableOp.isMasked()) {
106  transposedMask = vector::TransposeOp::create(
107  rewriter, loc, maskableOp.getMaskingOp().getMask(), indices);
108  }
109 
110  // Transpose reduction source.
111  auto transposeOp = vector::TransposeOp::create(rewriter, loc, src, indices);
112  SmallVector<bool> reductionMask(srcRank, false);
113  for (int i = 0; i < reductionSize; ++i) {
114  if (useInnerDimsForReduction)
115  reductionMask[srcRank - i - 1] = true;
116  else
117  reductionMask[i] = true;
118  }
119 
120  Operation *newMultiRedOp = vector::MultiDimReductionOp::create(
121  rewriter, multiReductionOp.getLoc(), transposeOp.getResult(),
122  multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
123  newMultiRedOp =
124  mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask);
125 
126  rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0));
127  return success();
128  }
129 
130 private:
131  const bool useInnerDimsForReduction;
132 };
133 
134 /// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
135 /// dimensions are either inner most or outer most.
136 class ReduceMultiDimReductionRank
137  : public OpRewritePattern<vector::MultiDimReductionOp> {
138 public:
140 
141  explicit ReduceMultiDimReductionRank(
142  MLIRContext *context, vector::VectorMultiReductionLowering options,
143  PatternBenefit benefit = 1)
145  useInnerDimsForReduction(
146  options == vector::VectorMultiReductionLowering::InnerReduction) {}
147 
148  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
149  PatternRewriter &rewriter) const override {
150  // Vector mask setup.
151  OpBuilder::InsertionGuard guard(rewriter);
152  auto maskableOp =
153  cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
154  Operation *rootOp;
155  if (maskableOp.isMasked()) {
156  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
157  rootOp = maskableOp.getMaskingOp();
158  } else {
159  rootOp = multiReductionOp;
160  }
161 
162  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
163  auto srcShape = multiReductionOp.getSourceVectorType().getShape();
164  auto srcScalableDims =
165  multiReductionOp.getSourceVectorType().getScalableDims();
166  auto loc = multiReductionOp.getLoc();
167 
168  // If rank less than 2, nothing to do.
169  if (srcRank < 2)
170  return failure();
171 
172  // Allow only 1 scalable dimensions. Otherwise we could end-up with e.g.
173  // `vscale * vscale` that's currently not modelled.
174  if (llvm::count(srcScalableDims, true) > 1)
175  return failure();
176 
177  // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
178  SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
179  if (srcRank == 2 && reductionMask.front() != reductionMask.back())
180  return failure();
181 
182  // 1. Separate reduction and parallel dims.
183  SmallVector<int64_t, 4> parallelDims, parallelShapes;
184  SmallVector<bool, 4> parallelScalableDims;
185  SmallVector<int64_t, 4> reductionDims, reductionShapes;
186  bool isReductionDimScalable = false;
187  for (const auto &it : llvm::enumerate(reductionMask)) {
188  int64_t i = it.index();
189  bool isReduction = it.value();
190  if (isReduction) {
191  reductionDims.push_back(i);
192  reductionShapes.push_back(srcShape[i]);
193  isReductionDimScalable |= srcScalableDims[i];
194  } else {
195  parallelDims.push_back(i);
196  parallelShapes.push_back(srcShape[i]);
197  parallelScalableDims.push_back(srcScalableDims[i]);
198  }
199  }
200 
201  // 2. Compute flattened parallel and reduction sizes.
202  int flattenedParallelDim = 0;
203  int flattenedReductionDim = 0;
204  if (!parallelShapes.empty()) {
205  flattenedParallelDim = 1;
206  for (auto d : parallelShapes)
207  flattenedParallelDim *= d;
208  }
209  if (!reductionShapes.empty()) {
210  flattenedReductionDim = 1;
211  for (auto d : reductionShapes)
212  flattenedReductionDim *= d;
213  }
214  // We must at least have some parallel or some reduction.
215  assert((flattenedParallelDim || flattenedReductionDim) &&
216  "expected at least one parallel or reduction dim");
217 
218  // 3. Fail if reduction/parallel dims are not contiguous.
219  // Check parallelDims are exactly [0 .. size).
220  int64_t counter = 0;
221  if (useInnerDimsForReduction &&
222  llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
223  return failure();
224  // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
225  counter = reductionDims.size();
226  if (!useInnerDimsForReduction &&
227  llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
228  return failure();
229 
230  // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
231  // a single parallel (resp. reduction) dim.
233  SmallVector<bool, 2> scalableDims;
235  bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
236  if (flattenedParallelDim) {
237  mask.push_back(false);
238  vectorShape.push_back(flattenedParallelDim);
239  scalableDims.push_back(isParallelDimScalable);
240  }
241  if (flattenedReductionDim) {
242  mask.push_back(true);
243  vectorShape.push_back(flattenedReductionDim);
244  scalableDims.push_back(isReductionDimScalable);
245  }
246  if (!useInnerDimsForReduction && vectorShape.size() == 2) {
247  std::swap(mask.front(), mask.back());
248  std::swap(vectorShape.front(), vectorShape.back());
249  std::swap(scalableDims.front(), scalableDims.back());
250  }
251 
252  Value newVectorMask;
253  if (maskableOp.isMasked()) {
254  Value vectorMask = maskableOp.getMaskingOp().getMask();
255  auto maskCastedType = VectorType::get(
256  vectorShape,
257  llvm::cast<VectorType>(vectorMask.getType()).getElementType());
258  newVectorMask = vector::ShapeCastOp::create(rewriter, loc, maskCastedType,
259  vectorMask);
260  }
261 
262  auto castedType = VectorType::get(
263  vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
264  scalableDims);
265  Value cast = vector::ShapeCastOp::create(rewriter, loc, castedType,
266  multiReductionOp.getSource());
267 
268  Value acc = multiReductionOp.getAcc();
269  if (flattenedParallelDim) {
270  auto accType = VectorType::get(
271  {flattenedParallelDim},
272  multiReductionOp.getSourceVectorType().getElementType(),
273  /*scalableDims=*/{isParallelDimScalable});
274  acc = vector::ShapeCastOp::create(rewriter, loc, accType, acc);
275  }
276  // 6. Creates the flattened form of vector.multi_reduction with inner/outer
277  // most dim as reduction.
278  Operation *newMultiDimRedOp = vector::MultiDimReductionOp::create(
279  rewriter, loc, cast, acc, mask, multiReductionOp.getKind());
280  newMultiDimRedOp =
281  mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask);
282 
283  // 7. If there are no parallel shapes, the result is a scalar.
284  // TODO: support 0-d vectors when available.
285  if (parallelShapes.empty()) {
286  rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0));
287  return success();
288  }
289 
290  // 8. Creates shape cast for the output n-D -> 2-D.
291  VectorType outputCastedType = VectorType::get(
292  parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
293  parallelScalableDims);
294  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
295  rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
296  return success();
297  }
298 
299 private:
300  const bool useInnerDimsForReduction;
301 };
302 
303 /// Unrolls vector.multi_reduction with outermost reductions
304 /// and combines results
305 struct TwoDimMultiReductionToElementWise
306  : public OpRewritePattern<vector::MultiDimReductionOp> {
308 
309  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
310  PatternRewriter &rewriter) const override {
311  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
312  // Rank-2 ["parallel", "reduce"] or bail.
313  if (srcRank != 2)
314  return failure();
315 
316  if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
317  return failure();
318 
319  auto loc = multiReductionOp.getLoc();
320  ArrayRef<int64_t> srcShape =
321  multiReductionOp.getSourceVectorType().getShape();
322 
323  Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
324  if (!elementType.isIntOrIndexOrFloat())
325  return failure();
326 
327  OpBuilder::InsertionGuard guard(rewriter);
328  auto maskableOp =
329  cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
330  Operation *rootOp;
331  Value mask = nullptr;
332  if (maskableOp.isMasked()) {
333  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
334  rootOp = maskableOp.getMaskingOp();
335  mask = maskableOp.getMaskingOp().getMask();
336  } else {
337  rootOp = multiReductionOp;
338  }
339 
340  Value result = multiReductionOp.getAcc();
341  for (int64_t i = 0; i < srcShape[0]; i++) {
342  auto operand = vector::ExtractOp::create(rewriter, loc,
343  multiReductionOp.getSource(), i);
344  Value extractMask = nullptr;
345  if (mask) {
346  extractMask = vector::ExtractOp::create(rewriter, loc, mask, i);
347  }
348  result =
349  makeArithReduction(rewriter, loc, multiReductionOp.getKind(), operand,
350  result, /*fastmath=*/nullptr, extractMask);
351  }
352 
353  rewriter.replaceOp(rootOp, result);
354  return success();
355  }
356 };
357 
358 /// Converts 2d vector.multi_reduction with inner most reduction dimension into
359 /// a sequence of vector.reduction ops.
360 struct TwoDimMultiReductionToReduction
361  : public OpRewritePattern<vector::MultiDimReductionOp> {
363 
364  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
365  PatternRewriter &rewriter) const override {
366  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
367  if (srcRank != 2)
368  return failure();
369 
370  if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
371  return failure();
372 
373  // Vector mask setup.
374  OpBuilder::InsertionGuard guard(rewriter);
375  auto maskableOp =
376  cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
377  Operation *rootOp;
378  if (maskableOp.isMasked()) {
379  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
380  rootOp = maskableOp.getMaskingOp();
381  } else {
382  rootOp = multiReductionOp;
383  }
384 
385  auto loc = multiReductionOp.getLoc();
386  Value result = arith::ConstantOp::create(
387  rewriter, loc, multiReductionOp.getDestType(),
388  rewriter.getZeroAttr(multiReductionOp.getDestType()));
389  int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
390 
391  for (int i = 0; i < outerDim; ++i) {
392  auto v = vector::ExtractOp::create(
393  rewriter, loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
394  auto acc = vector::ExtractOp::create(
395  rewriter, loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
396  Operation *reductionOp = vector::ReductionOp::create(
397  rewriter, loc, multiReductionOp.getKind(), v, acc);
398 
399  // If masked, slice the mask and mask the new reduction operation.
400  if (maskableOp.isMasked()) {
401  Value mask = vector::ExtractOp::create(
402  rewriter, loc, maskableOp.getMaskingOp().getMask(),
403  ArrayRef<int64_t>{i});
404  reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
405  }
406 
407  result = vector::InsertOp::create(rewriter, loc,
408  reductionOp->getResult(0), result, i);
409  }
410 
411  rewriter.replaceOp(rootOp, result);
412  return success();
413  }
414 };
415 
416 /// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
417 /// form with both a single parallel and reduction dimension.
418 /// This is achieved with a simple vector.shape_cast that inserts a leading 1.
419 /// The case with a single parallel dimension is a noop and folds away
420 /// separately.
421 struct OneDimMultiReductionToTwoDim
422  : public OpRewritePattern<vector::MultiDimReductionOp> {
424 
425  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
426  PatternRewriter &rewriter) const override {
427  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
428  // Rank-1 or bail.
429  if (srcRank != 1)
430  return failure();
431 
432  // Vector mask setup.
433  OpBuilder::InsertionGuard guard(rewriter);
434  auto maskableOp =
435  cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
436  Operation *rootOp;
437  Value mask;
438  if (maskableOp.isMasked()) {
439  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
440  rootOp = maskableOp.getMaskingOp();
441  mask = maskableOp.getMaskingOp().getMask();
442  } else {
443  rootOp = multiReductionOp;
444  }
445 
446  auto loc = multiReductionOp.getLoc();
447  auto srcVectorType = multiReductionOp.getSourceVectorType();
448  auto srcShape = srcVectorType.getShape();
449  auto castedType = VectorType::get(
450  ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),
451  ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});
452 
453  auto accType =
454  VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
455  assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
456  "multi_reduction with a single dimension expects a scalar result");
457 
458  // If the unique dim is reduced and we insert a parallel in front, we need a
459  // {false, true} mask.
460  SmallVector<bool, 2> reductionMask{false, true};
461 
462  /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
463  Value cast = vector::ShapeCastOp::create(rewriter, loc, castedType,
464  multiReductionOp.getSource());
465  Value castAcc = vector::BroadcastOp::create(rewriter, loc, accType,
466  multiReductionOp.getAcc());
467  Value castMask;
468  if (maskableOp.isMasked()) {
469  auto maskType = llvm::cast<VectorType>(mask.getType());
470  auto castMaskType = VectorType::get(
471  ArrayRef<int64_t>{1, maskType.getShape().back()},
472  maskType.getElementType(),
473  ArrayRef<bool>{false, maskType.getScalableDims().back()});
474  castMask = vector::BroadcastOp::create(rewriter, loc, castMaskType, mask);
475  }
476 
477  Operation *newOp = vector::MultiDimReductionOp::create(
478  rewriter, loc, cast, castAcc, reductionMask,
479  multiReductionOp.getKind());
480  newOp = vector::maskOperation(rewriter, newOp, castMask);
481 
482  rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
483  ArrayRef<int64_t>{0});
484  return success();
485  }
486 };
487 
488 struct LowerVectorMultiReductionPass
489  : public vector::impl::LowerVectorMultiReductionBase<
490  LowerVectorMultiReductionPass> {
491  LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
492  this->loweringStrategy = option;
493  }
494 
495  void runOnOperation() override {
496  Operation *op = getOperation();
497  MLIRContext *context = op->getContext();
498 
499  RewritePatternSet loweringPatterns(context);
501  this->loweringStrategy);
502 
503  if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
504  signalPassFailure();
505  }
506 
507  void getDependentDialects(DialectRegistry &registry) const override {
508  registry.insert<vector::VectorDialect>();
509  }
510 };
511 
512 } // namespace
513 
515  RewritePatternSet &patterns, VectorMultiReductionLowering options,
516  PatternBenefit benefit) {
517  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
518  patterns.getContext(), options, benefit);
519  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
520  if (options == VectorMultiReductionLowering ::InnerReduction)
521  patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
522  benefit);
523  else
524  patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
525  benefit);
526 }
527 
528 std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
529  vector::VectorMultiReductionLowering option) {
530  return std::make_unique<LowerVectorMultiReductionPass>(option);
531 }
static Type getElementType(Type type)
Determine the element type of type.
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< VectorShape > vectorShape(Type type)
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:319
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:120
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
std::unique_ptr< Pass > createLowerVectorMultiReductionPass(VectorMultiReductionLowering option=VectorMultiReductionLowering::InnerParallel)
Creates an instance of the vector.multi_reduction lowering pass.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector....
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319