MLIR  20.0.0git
LowerVectorMultiReduction.cpp
Go to the documentation of this file.
1 //===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===//
2 //
3 /// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 /// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5 /// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.multi_reduction' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/TypeUtilities.h"
21 
22 namespace mlir {
23 namespace vector {
24 #define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
25 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
26 } // namespace vector
27 } // namespace mlir
28 
29 #define DEBUG_TYPE "vector-multi-reduction"
30 
31 using namespace mlir;
32 
33 namespace {
34 /// This file implements the following transformations as composable atomic
35 /// patterns.
36 
37 /// Converts vector.multi_reduction into inner-most/outer-most reduction form
38 /// by using vector.transpose
39 class InnerOuterDimReductionConversion
40  : public OpRewritePattern<vector::MultiDimReductionOp> {
41 public:
43 
44  explicit InnerOuterDimReductionConversion(
45  MLIRContext *context, vector::VectorMultiReductionLowering options,
46  PatternBenefit benefit = 1)
48  useInnerDimsForReduction(
49  options == vector::VectorMultiReductionLowering::InnerReduction) {}
50 
51  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
52  PatternRewriter &rewriter) const override {
53  // Vector mask setup.
54  OpBuilder::InsertionGuard guard(rewriter);
55  auto maskableOp =
56  cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
57  Operation *rootOp;
58  if (maskableOp.isMasked()) {
59  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
60  rootOp = maskableOp.getMaskingOp();
61  } else {
62  rootOp = multiReductionOp;
63  }
64 
65  auto src = multiReductionOp.getSource();
66  auto loc = multiReductionOp.getLoc();
67  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
68 
69  // Separate reduction and parallel dims
70  ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
71  llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
72  reductionDims.end());
73  int64_t reductionSize = reductionDims.size();
74  SmallVector<int64_t, 4> parallelDims;
75  for (int64_t i = 0; i < srcRank; ++i)
76  if (!reductionDimsSet.contains(i))
77  parallelDims.push_back(i);
78 
79  // Add transpose only if inner-most/outer-most dimensions are not parallel
80  // and there are parallel dims.
81  if (parallelDims.empty())
82  return failure();
83  if (useInnerDimsForReduction &&
84  (parallelDims ==
85  llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
86  return failure();
87 
88  if (!useInnerDimsForReduction &&
89  (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
90  reductionDims.size(),
91  parallelDims.size() + reductionDims.size()))))
92  return failure();
93 
95  if (useInnerDimsForReduction) {
96  indices.append(parallelDims.begin(), parallelDims.end());
97  indices.append(reductionDims.begin(), reductionDims.end());
98  } else {
99  indices.append(reductionDims.begin(), reductionDims.end());
100  indices.append(parallelDims.begin(), parallelDims.end());
101  }
102 
103  // If masked, transpose the original mask.
104  Value transposedMask;
105  if (maskableOp.isMasked()) {
106  transposedMask = rewriter.create<vector::TransposeOp>(
107  loc, maskableOp.getMaskingOp().getMask(), indices);
108  }
109 
110  // Transpose reduction source.
111  auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
112  SmallVector<bool> reductionMask(srcRank, false);
113  for (int i = 0; i < reductionSize; ++i) {
114  if (useInnerDimsForReduction)
115  reductionMask[srcRank - i - 1] = true;
116  else
117  reductionMask[i] = true;
118  }
119 
120  Operation *newMultiRedOp = rewriter.create<vector::MultiDimReductionOp>(
121  multiReductionOp.getLoc(), transposeOp.getResult(),
122  multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
123  newMultiRedOp =
124  mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask);
125 
126  rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0));
127  return success();
128  }
129 
130 private:
131  const bool useInnerDimsForReduction;
132 };
133 
134 /// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
135 /// dimensions are either inner most or outer most.
136 class ReduceMultiDimReductionRank
137  : public OpRewritePattern<vector::MultiDimReductionOp> {
138 public:
140 
141  explicit ReduceMultiDimReductionRank(
142  MLIRContext *context, vector::VectorMultiReductionLowering options,
143  PatternBenefit benefit = 1)
145  useInnerDimsForReduction(
146  options == vector::VectorMultiReductionLowering::InnerReduction) {}
147 
148  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
149  PatternRewriter &rewriter) const override {
150  // Vector mask setup.
151  OpBuilder::InsertionGuard guard(rewriter);
152  auto maskableOp =
153  cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
154  Operation *rootOp;
155  if (maskableOp.isMasked()) {
156  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
157  rootOp = maskableOp.getMaskingOp();
158  } else {
159  rootOp = multiReductionOp;
160  }
161 
162  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
163  auto srcShape = multiReductionOp.getSourceVectorType().getShape();
164  auto srcScalableDims =
165  multiReductionOp.getSourceVectorType().getScalableDims();
166  auto loc = multiReductionOp.getLoc();
167 
168  // If rank less than 2, nothing to do.
169  if (srcRank < 2)
170  return failure();
171 
172  // Allow only 1 scalable dimensions. Otherwise we could end-up with e.g.
173  // `vscale * vscale` that's currently not modelled.
174  if (llvm::count(srcScalableDims, true) > 1)
175  return failure();
176 
177  // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
178  SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
179  if (srcRank == 2 && reductionMask.front() != reductionMask.back())
180  return failure();
181 
182  // 1. Separate reduction and parallel dims.
183  SmallVector<int64_t, 4> parallelDims, parallelShapes;
184  SmallVector<bool, 4> parallelScalableDims;
185  SmallVector<int64_t, 4> reductionDims, reductionShapes;
186  bool isReductionDimScalable = false;
187  for (const auto &it : llvm::enumerate(reductionMask)) {
188  int64_t i = it.index();
189  bool isReduction = it.value();
190  if (isReduction) {
191  reductionDims.push_back(i);
192  reductionShapes.push_back(srcShape[i]);
193  isReductionDimScalable |= srcScalableDims[i];
194  } else {
195  parallelDims.push_back(i);
196  parallelShapes.push_back(srcShape[i]);
197  parallelScalableDims.push_back(srcScalableDims[i]);
198  }
199  }
200 
201  // 2. Compute flattened parallel and reduction sizes.
202  int flattenedParallelDim = 0;
203  int flattenedReductionDim = 0;
204  if (!parallelShapes.empty()) {
205  flattenedParallelDim = 1;
206  for (auto d : parallelShapes)
207  flattenedParallelDim *= d;
208  }
209  if (!reductionShapes.empty()) {
210  flattenedReductionDim = 1;
211  for (auto d : reductionShapes)
212  flattenedReductionDim *= d;
213  }
214  // We must at least have some parallel or some reduction.
215  assert((flattenedParallelDim || flattenedReductionDim) &&
216  "expected at least one parallel or reduction dim");
217 
218  // 3. Fail if reduction/parallel dims are not contiguous.
219  // Check parallelDims are exactly [0 .. size).
220  int64_t counter = 0;
221  if (useInnerDimsForReduction &&
222  llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
223  return failure();
224  // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
225  counter = reductionDims.size();
226  if (!useInnerDimsForReduction &&
227  llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
228  return failure();
229 
230  // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
231  // a single parallel (resp. reduction) dim.
233  SmallVector<bool, 2> scalableDims;
235  bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
236  if (flattenedParallelDim) {
237  mask.push_back(false);
238  vectorShape.push_back(flattenedParallelDim);
239  scalableDims.push_back(isParallelDimScalable);
240  }
241  if (flattenedReductionDim) {
242  mask.push_back(true);
243  vectorShape.push_back(flattenedReductionDim);
244  scalableDims.push_back(isReductionDimScalable);
245  }
246  if (!useInnerDimsForReduction && vectorShape.size() == 2) {
247  std::swap(mask.front(), mask.back());
248  std::swap(vectorShape.front(), vectorShape.back());
249  std::swap(scalableDims.front(), scalableDims.back());
250  }
251 
252  Value newVectorMask;
253  if (maskableOp.isMasked()) {
254  Value vectorMask = maskableOp.getMaskingOp().getMask();
255  auto maskCastedType = VectorType::get(
256  vectorShape,
257  llvm::cast<VectorType>(vectorMask.getType()).getElementType());
258  newVectorMask =
259  rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
260  }
261 
262  auto castedType = VectorType::get(
263  vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
264  scalableDims);
265  Value cast = rewriter.create<vector::ShapeCastOp>(
266  loc, castedType, multiReductionOp.getSource());
267 
268  Value acc = multiReductionOp.getAcc();
269  if (flattenedParallelDim) {
270  auto accType = VectorType::get(
271  {flattenedParallelDim},
272  multiReductionOp.getSourceVectorType().getElementType(),
273  /*scalableDims=*/{isParallelDimScalable});
274  acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
275  }
276  // 6. Creates the flattened form of vector.multi_reduction with inner/outer
277  // most dim as reduction.
278  Operation *newMultiDimRedOp = rewriter.create<vector::MultiDimReductionOp>(
279  loc, cast, acc, mask, multiReductionOp.getKind());
280  newMultiDimRedOp =
281  mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask);
282 
283  // 7. If there are no parallel shapes, the result is a scalar.
284  // TODO: support 0-d vectors when available.
285  if (parallelShapes.empty()) {
286  rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0));
287  return success();
288  }
289 
290  // 8. Creates shape cast for the output n-D -> 2-D.
291  VectorType outputCastedType = VectorType::get(
292  parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
293  parallelScalableDims);
294  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
295  rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
296  return success();
297  }
298 
299 private:
300  const bool useInnerDimsForReduction;
301 };
302 
303 /// Unrolls vector.multi_reduction with outermost reductions
304 /// and combines results
305 struct TwoDimMultiReductionToElementWise
306  : public OpRewritePattern<vector::MultiDimReductionOp> {
308 
309  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
310  PatternRewriter &rewriter) const override {
311  auto maskableOp =
312  cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
313  if (maskableOp.isMasked())
314  // TODO: Support masking.
315  return failure();
316 
317  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
318  // Rank-2 ["parallel", "reduce"] or bail.
319  if (srcRank != 2)
320  return failure();
321 
322  if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
323  return failure();
324 
325  auto loc = multiReductionOp.getLoc();
326  ArrayRef<int64_t> srcShape =
327  multiReductionOp.getSourceVectorType().getShape();
328 
329  Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
330  if (!elementType.isIntOrIndexOrFloat())
331  return failure();
332 
333  Value result = multiReductionOp.getAcc();
334  for (int64_t i = 0; i < srcShape[0]; i++) {
335  auto operand = rewriter.create<vector::ExtractOp>(
336  loc, multiReductionOp.getSource(), i);
337  result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
338  operand, result);
339  }
340 
341  rewriter.replaceOp(multiReductionOp, result);
342  return success();
343  }
344 };
345 
346 /// Converts 2d vector.multi_reduction with inner most reduction dimension into
347 /// a sequence of vector.reduction ops.
348 struct TwoDimMultiReductionToReduction
349  : public OpRewritePattern<vector::MultiDimReductionOp> {
351 
352  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
353  PatternRewriter &rewriter) const override {
354  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
355  if (srcRank != 2)
356  return failure();
357 
358  if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
359  return failure();
360 
361  // Vector mask setup.
362  OpBuilder::InsertionGuard guard(rewriter);
363  auto maskableOp =
364  cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
365  Operation *rootOp;
366  if (maskableOp.isMasked()) {
367  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
368  rootOp = maskableOp.getMaskingOp();
369  } else {
370  rootOp = multiReductionOp;
371  }
372 
373  auto loc = multiReductionOp.getLoc();
374  Value result = rewriter.create<arith::ConstantOp>(
375  loc, multiReductionOp.getDestType(),
376  rewriter.getZeroAttr(multiReductionOp.getDestType()));
377  int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
378 
379  for (int i = 0; i < outerDim; ++i) {
380  auto v = rewriter.create<vector::ExtractOp>(
381  loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
382  auto acc = rewriter.create<vector::ExtractOp>(
383  loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
384  Operation *reductionOp = rewriter.create<vector::ReductionOp>(
385  loc, multiReductionOp.getKind(), v, acc);
386 
387  // If masked, slice the mask and mask the new reduction operation.
388  if (maskableOp.isMasked()) {
389  Value mask = rewriter.create<vector::ExtractOp>(
390  loc, maskableOp.getMaskingOp().getMask(), ArrayRef<int64_t>{i});
391  reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
392  }
393 
394  result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
395  result, i);
396  }
397 
398  rewriter.replaceOp(rootOp, result);
399  return success();
400  }
401 };
402 
403 /// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
404 /// form with both a single parallel and reduction dimension.
405 /// This is achieved with a simple vector.shape_cast that inserts a leading 1.
406 /// The case with a single parallel dimension is a noop and folds away
407 /// separately.
408 struct OneDimMultiReductionToTwoDim
409  : public OpRewritePattern<vector::MultiDimReductionOp> {
411 
412  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
413  PatternRewriter &rewriter) const override {
414  auto srcRank = multiReductionOp.getSourceVectorType().getRank();
415  // Rank-1 or bail.
416  if (srcRank != 1)
417  return failure();
418 
419  // Vector mask setup.
420  OpBuilder::InsertionGuard guard(rewriter);
421  auto maskableOp =
422  cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
423  Operation *rootOp;
424  Value mask;
425  if (maskableOp.isMasked()) {
426  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
427  rootOp = maskableOp.getMaskingOp();
428  mask = maskableOp.getMaskingOp().getMask();
429  } else {
430  rootOp = multiReductionOp;
431  }
432 
433  auto loc = multiReductionOp.getLoc();
434  auto srcVectorType = multiReductionOp.getSourceVectorType();
435  auto srcShape = srcVectorType.getShape();
436  auto castedType = VectorType::get(
437  ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),
438  ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});
439 
440  auto accType =
441  VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
442  assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
443  "multi_reduction with a single dimension expects a scalar result");
444 
445  // If the unique dim is reduced and we insert a parallel in front, we need a
446  // {false, true} mask.
447  SmallVector<bool, 2> reductionMask{false, true};
448 
449  /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
450  Value cast = rewriter.create<vector::ShapeCastOp>(
451  loc, castedType, multiReductionOp.getSource());
452  Value castAcc = rewriter.create<vector::BroadcastOp>(
453  loc, accType, multiReductionOp.getAcc());
454  Value castMask;
455  if (maskableOp.isMasked()) {
456  auto maskType = llvm::cast<VectorType>(mask.getType());
457  auto castMaskType = VectorType::get(
458  ArrayRef<int64_t>{1, maskType.getShape().back()},
459  maskType.getElementType(),
460  ArrayRef<bool>{false, maskType.getScalableDims().back()});
461  castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
462  }
463 
464  Operation *newOp = rewriter.create<vector::MultiDimReductionOp>(
465  loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
466  newOp = vector::maskOperation(rewriter, newOp, castMask);
467 
468  rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
469  ArrayRef<int64_t>{0});
470  return success();
471  }
472 };
473 
474 struct LowerVectorMultiReductionPass
475  : public vector::impl::LowerVectorMultiReductionBase<
476  LowerVectorMultiReductionPass> {
477  LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
478  this->loweringStrategy = option;
479  }
480 
481  void runOnOperation() override {
482  Operation *op = getOperation();
483  MLIRContext *context = op->getContext();
484 
485  RewritePatternSet loweringPatterns(context);
487  this->loweringStrategy);
488 
489  if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
490  signalPassFailure();
491  }
492 
493  void getDependentDialects(DialectRegistry &registry) const override {
494  registry.insert<vector::VectorDialect>();
495  }
496 };
497 
498 } // namespace
499 
501  RewritePatternSet &patterns, VectorMultiReductionLowering options,
502  PatternBenefit benefit) {
503  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
504  patterns.getContext(), options, benefit);
505  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
506  if (options == VectorMultiReductionLowering ::InnerReduction)
507  patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
508  benefit);
509  else
510  patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
511  benefit);
512 }
513 
514 std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
515  vector::VectorMultiReductionLowering option) {
516  return std::make_unique<LowerVectorMultiReductionPass>(option);
517 }
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< VectorShape > vectorShape(Type type)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:131
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
std::unique_ptr< Pass > createLowerVectorMultiReductionPass(VectorMultiReductionLowering option=VectorMultiReductionLowering::InnerParallel)
Creates an instance of the vector.multi_reduction lowering pass.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector....
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362