//===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.multi_reduction' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace vector {
#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
} // namespace vector
} // namespace mlir

#define DEBUG_TYPE "vector-multi-reduction"

using namespace mlir;

namespace {
/// This file implements the following transformations as composable atomic
/// patterns.
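///
/// The op being lowered has the general form shown below (illustrative example
/// only; shapes, combining kind and reduction dimensions are arbitrary):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [1]
///       : vector<4x8xf32> to vector<4xf32>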

/// Converts vector.multi_reduction into inner-most/outer-most reduction form
/// by using vector.transpose.
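///
/// For example (illustrative sketch, assuming the InnerReduction strategy):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [0]
///       : vector<4x8xf32> to vector<8xf32>
///
/// is rewritten so that the reduction dimension becomes innermost:
///
///   %t = vector.transpose %src, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
///   %0 = vector.multi_reduction <add>, %t, %acc [1]
///       : vector<8x4xf32> to vector<8xf32>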
class InnerOuterDimReductionConversion
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  explicit InnerOuterDimReductionConversion(
      MLIRContext *context, vector::VectorMultiReductionLowering options,
      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
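    // For reference, a masked reduction is wrapped in a vector.mask op, e.g.
    // (illustrative):
    //   %0 = vector.mask %m {
    //     vector.multi_reduction <add>, %src, %acc [1]
    //         : vector<4x8xf32> to vector<4xf32>
    //   } : vector<4x8xi1> -> vector<4xf32>
    // so the rewrite must anchor on, and eventually replace, the mask op.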
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto src = multiReductionOp.getSource();
    auto loc = multiReductionOp.getLoc();
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();

    // Separate reduction and parallel dims.
    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
                                                  reductionDims.end());
    int64_t reductionSize = reductionDims.size();
    SmallVector<int64_t, 4> parallelDims;
    for (int64_t i = 0; i < srcRank; ++i)
      if (!reductionDimsSet.contains(i))
        parallelDims.push_back(i);

    // Add a transpose only if the inner-most/outer-most dimensions are not
    // already parallel and there are parallel dims.
    if (parallelDims.empty())
      return failure();
    if (useInnerDimsForReduction &&
        (parallelDims ==
         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
      return failure();

    if (!useInnerDimsForReduction &&
        (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
                             reductionDims.size(),
                             parallelDims.size() + reductionDims.size()))))
      return failure();

    SmallVector<int64_t> indices;
    if (useInnerDimsForReduction) {
      indices.append(parallelDims.begin(), parallelDims.end());
      indices.append(reductionDims.begin(), reductionDims.end());
    } else {
      indices.append(reductionDims.begin(), reductionDims.end());
      indices.append(parallelDims.begin(), parallelDims.end());
    }

    // If masked, transpose the original mask.
    Value transposedMask;
    if (maskableOp.isMasked()) {
      transposedMask = rewriter.create<vector::TransposeOp>(
          loc, maskableOp.getMaskingOp().getMask(), indices);
    }

    // Transpose the reduction source.
    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
    SmallVector<bool> reductionMask(srcRank, false);
    for (int i = 0; i < reductionSize; ++i) {
      if (useInnerDimsForReduction)
        reductionMask[srcRank - i - 1] = true;
      else
        reductionMask[i] = true;
    }

    Operation *newMultiRedOp = rewriter.create<vector::MultiDimReductionOp>(
        multiReductionOp.getLoc(), transposeOp.getResult(),
        multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
    newMultiRedOp =
        mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask);

    rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0));
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Reduces the rank of vector.multi_reduction from n-D to 2-D, given that all
/// reduction dimensions are either innermost or outermost.
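///
/// For example (illustrative sketch, inner-reduction case):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [2, 3]
///       : vector<2x3x4x5xf32> to vector<2x3xf32>
///
/// is flattened into a rank-2 reduction via shape casts:
///
///   %s = vector.shape_cast %src : vector<2x3x4x5xf32> to vector<6x20xf32>
///   %a = vector.shape_cast %acc : vector<2x3xf32> to vector<6xf32>
///   %r = vector.multi_reduction <add>, %s, %a [1]
///       : vector<6x20xf32> to vector<6xf32>
///   %0 = vector.shape_cast %r : vector<6xf32> to vector<2x3xf32>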
class ReduceMultiDimReductionRank
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  explicit ReduceMultiDimReductionRank(
      MLIRContext *context, vector::VectorMultiReductionLowering options,
      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    auto srcShape = multiReductionOp.getSourceVectorType().getShape();
    auto srcScalableDims =
        multiReductionOp.getSourceVectorType().getScalableDims();
    auto loc = multiReductionOp.getLoc();

    // If the rank is less than 2, there is nothing to do.
    if (srcRank < 2)
      return failure();

    // Allow only one scalable dimension. Otherwise we could end up with e.g.
    // `vscale * vscale`, which is currently not modelled.
    if (llvm::count(srcScalableDims, true) > 1)
      return failure();

    // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"], bail.
    SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
    if (srcRank == 2 && reductionMask.front() != reductionMask.back())
      return failure();

    // 1. Separate reduction and parallel dims.
    SmallVector<int64_t, 4> parallelDims, parallelShapes;
    SmallVector<bool, 4> parallelScalableDims;
    SmallVector<int64_t, 4> reductionDims, reductionShapes;
    bool isReductionDimScalable = false;
    for (const auto &it : llvm::enumerate(reductionMask)) {
      int64_t i = it.index();
      bool isReduction = it.value();
      if (isReduction) {
        reductionDims.push_back(i);
        reductionShapes.push_back(srcShape[i]);
        isReductionDimScalable |= srcScalableDims[i];
      } else {
        parallelDims.push_back(i);
        parallelShapes.push_back(srcShape[i]);
        parallelScalableDims.push_back(srcScalableDims[i]);
      }
    }

    // 2. Compute flattened parallel and reduction sizes.
    int flattenedParallelDim = 0;
    int flattenedReductionDim = 0;
    if (!parallelShapes.empty()) {
      flattenedParallelDim = 1;
      for (auto d : parallelShapes)
        flattenedParallelDim *= d;
    }
    if (!reductionShapes.empty()) {
      flattenedReductionDim = 1;
      for (auto d : reductionShapes)
        flattenedReductionDim *= d;
    }
    // We must have at least some parallel or some reduction dims.
    assert((flattenedParallelDim || flattenedReductionDim) &&
           "expected at least one parallel or reduction dim");

    // 3. Fail if reduction/parallel dims are not contiguous.
    // Check that parallelDims are exactly [0 .. size).
    int64_t counter = 0;
    if (useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();
    // Check that parallelDims are exactly {reductionDims.size()} + [0 .. size).
    counter = reductionDims.size();
    if (!useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();

    // 4. Shape cast to collapse consecutive parallel (resp. reduction) dims
    // into a single parallel (resp. reduction) dim.
    SmallVector<bool, 2> mask;
    SmallVector<bool, 2> scalableDims;
    SmallVector<int64_t, 2> vectorShape;
    bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
    if (flattenedParallelDim) {
      mask.push_back(false);
      vectorShape.push_back(flattenedParallelDim);
      scalableDims.push_back(isParallelDimScalable);
    }
    if (flattenedReductionDim) {
      mask.push_back(true);
      vectorShape.push_back(flattenedReductionDim);
      scalableDims.push_back(isReductionDimScalable);
    }
    if (!useInnerDimsForReduction && vectorShape.size() == 2) {
      std::swap(mask.front(), mask.back());
      std::swap(vectorShape.front(), vectorShape.back());
      std::swap(scalableDims.front(), scalableDims.back());
    }

    Value newVectorMask;
    if (maskableOp.isMasked()) {
      Value vectorMask = maskableOp.getMaskingOp().getMask();
      auto maskCastedType = VectorType::get(
          vectorShape,
          llvm::cast<VectorType>(vectorMask.getType()).getElementType());
      newVectorMask =
          rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
    }

    auto castedType = VectorType::get(
        vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
        scalableDims);
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());

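    // 5. Shape-cast the accumulator to the flattened parallel shape, if any.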
    Value acc = multiReductionOp.getAcc();
    if (flattenedParallelDim) {
      auto accType = VectorType::get(
          {flattenedParallelDim},
          multiReductionOp.getSourceVectorType().getElementType(),
          /*scalableDims=*/{isParallelDimScalable});
      acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
    }
    // 6. Create the flattened form of vector.multi_reduction with the
    // innermost/outermost dim as reduction.
    Operation *newMultiDimRedOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, acc, mask, multiReductionOp.getKind());
    newMultiDimRedOp =
        mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask);

    // 7. If there are no parallel shapes, the result is a scalar.
    // TODO: support 0-d vectors when available.
    if (parallelShapes.empty()) {
      rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0));
      return success();
    }

    // 8. Create a shape cast for the output n-D -> 2-D.
    VectorType outputCastedType = VectorType::get(
        parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
        parallelScalableDims);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Unrolls a rank-2 vector.multi_reduction with an outermost reduction
/// dimension and combines the extracted rows elementwise.
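///
/// For example (illustrative sketch):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [0]
///       : vector<2x8xf32> to vector<8xf32>
///
/// becomes a chain of elementwise combines over the extracted rows:
///
///   %r0 = vector.extract %src[0] : vector<8xf32> from vector<2x8xf32>
///   %a0 = arith.addf %r0, %acc : vector<8xf32>
///   %r1 = vector.extract %src[1] : vector<8xf32> from vector<2x8xf32>
///   %0  = arith.addf %r1, %a0 : vector<8xf32>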
struct TwoDimMultiReductionToElementWise
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    if (maskableOp.isMasked())
      // TODO: Support masking.
      return failure();

    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-2 ["reduce", "parallel"] or bail.
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
      return failure();

    auto loc = multiReductionOp.getLoc();
    ArrayRef<int64_t> srcShape =
        multiReductionOp.getSourceVectorType().getShape();

    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
    if (!elementType.isIntOrIndexOrFloat())
      return failure();

    Value result = multiReductionOp.getAcc();
    for (int64_t i = 0; i < srcShape[0]; i++) {
      auto operand = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), i);
      result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
                                  operand, result);
    }

    rewriter.replaceOp(multiReductionOp, result);
    return success();
  }
};

/// Converts a 2-D vector.multi_reduction with the innermost reduction
/// dimension into a sequence of vector.reduction ops.
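///
/// For example (illustrative sketch):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [1]
///       : vector<2x8xf32> to vector<2xf32>
///
/// becomes one vector.reduction per outer (parallel) element:
///
///   %cst = arith.constant dense<0.000000e+00> : vector<2xf32>
///   %v0 = vector.extract %src[0] : vector<8xf32> from vector<2x8xf32>
///   %a0 = vector.extract %acc[0] : f32 from vector<2xf32>
///   %r0 = vector.reduction <add>, %v0, %a0 : vector<8xf32> into f32
///   %i0 = vector.insert %r0, %cst[0] : f32 into vector<2xf32>
///   %v1 = vector.extract %src[1] : vector<8xf32> from vector<2x8xf32>
///   %a1 = vector.extract %acc[1] : f32 from vector<2xf32>
///   %r1 = vector.reduction <add>, %v1, %a1 : vector<8xf32> into f32
///   %0  = vector.insert %r1, %i0[1] : f32 into vector<2xf32>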
struct TwoDimMultiReductionToReduction
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
      return failure();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto loc = multiReductionOp.getLoc();
    Value result = rewriter.create<arith::ConstantOp>(
        loc, multiReductionOp.getDestType(),
        rewriter.getZeroAttr(multiReductionOp.getDestType()));
    int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];

    for (int i = 0; i < outerDim; ++i) {
      auto v = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
      auto acc = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
      Operation *reductionOp = rewriter.create<vector::ReductionOp>(
          loc, multiReductionOp.getKind(), v, acc);

      // If masked, slice the mask and mask the new reduction operation.
      if (maskableOp.isMasked()) {
        Value mask = rewriter.create<vector::ExtractOp>(
            loc, maskableOp.getMaskingOp().getMask(), ArrayRef<int64_t>{i});
        reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
      }

      result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
                                                 result, i);
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

/// Converts a 1-D vector.multi_reduction with a single reduction dimension to
/// a 2-D form with one parallel and one reduction dimension.
/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
/// The case with a single parallel dimension is a noop and folds away
/// separately.
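///
/// For example (illustrative sketch):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [0]
///       : vector<8xf32> to f32
///
/// becomes
///
///   %c = vector.shape_cast %src : vector<8xf32> to vector<1x8xf32>
///   %a = vector.broadcast %acc : f32 to vector<1xf32>
///   %r = vector.multi_reduction <add>, %c, %a [1]
///       : vector<1x8xf32> to vector<1xf32>
///   %0 = vector.extract %r[0] : f32 from vector<1xf32>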
struct OneDimMultiReductionToTwoDim
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-1 or bail.
    if (srcRank != 1)
      return failure();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = multiReductionOp;
    }

    auto loc = multiReductionOp.getLoc();
    auto srcVectorType = multiReductionOp.getSourceVectorType();
    auto srcShape = srcVectorType.getShape();
    auto castedType = VectorType::get(
        ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),
        ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});

    auto accType =
        VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
    assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
           "multi_reduction with a single dimension expects a scalar result");

    // If the unique dim is reduced and we insert a parallel dim in front, we
    // need a {false, true} mask.
    SmallVector<bool, 2> reductionMask{false, true};

    /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());
    Value castAcc = rewriter.create<vector::BroadcastOp>(
        loc, accType, multiReductionOp.getAcc());
    Value castMask;
    if (maskableOp.isMasked()) {
      auto maskType = llvm::cast<VectorType>(mask.getType());
      auto castMaskType = VectorType::get(
          ArrayRef<int64_t>{1, maskType.getShape().back()},
          maskType.getElementType(),
          ArrayRef<bool>{false, maskType.getScalableDims().back()});
      castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
    }

    Operation *newOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
    newOp = vector::maskOperation(rewriter, newOp, castMask);

    rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
                                                   ArrayRef<int64_t>{0});
    return success();
  }
};

struct LowerVectorMultiReductionPass
    : public vector::impl::LowerVectorMultiReductionBase<
          LowerVectorMultiReductionPass> {
  LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
    this->loweringStrategy = option;
  }

  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();

    RewritePatternSet loweringPatterns(context);
    vector::populateVectorMultiReductionLoweringPatterns(
        loweringPatterns, this->loweringStrategy);

    if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
      signalPassFailure();
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }
};

} // namespace

void mlir::vector::populateVectorMultiReductionLoweringPatterns(
    RewritePatternSet &patterns, VectorMultiReductionLowering options,
    PatternBenefit benefit) {
  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
      patterns.getContext(), options, benefit);
  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
  if (options == VectorMultiReductionLowering::InnerReduction)
    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
                                                  benefit);
  else
    patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
                                                    benefit);
}

std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
    vector::VectorMultiReductionLowering option) {
  return std::make_unique<LowerVectorMultiReductionPass>(option);
}
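
// Usage sketch (editorial note, not part of the original file): callers can
// either populate their own pattern set with the lowerings defined above, e.g.
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorMultiReductionLoweringPatterns(
//       patterns, vector::VectorMultiReductionLowering::InnerReduction);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
//
// or run the standalone pass returned by
// vector::createLowerVectorMultiReductionPass().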