MLIR 23.0.0git
LowerVectorMultiReduction.cpp
Go to the documentation of this file.
1//===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===//
2//
3/// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4/// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5/// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to lower the
10// 'vector.multi_reduction' operation.
11//
12//===----------------------------------------------------------------------===//
13
18#include "mlir/IR/Builders.h"
21
22namespace mlir {
23namespace vector {
24#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
25#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
26} // namespace vector
27} // namespace mlir
28
29#define DEBUG_TYPE "vector-multi-reduction"
30
31using namespace mlir;
32
33namespace {
34/// This file implements the following transformations as composable atomic
35/// patterns.
36
37/// Converts vector.multi_reduction into inner-most/outer-most reduction form
38/// by using vector.transpose
39class InnerOuterDimReductionConversion
40 : public OpRewritePattern<vector::MultiDimReductionOp> {
41public:
42 using Base::Base;
43
44 explicit InnerOuterDimReductionConversion(
45 MLIRContext *context, vector::VectorMultiReductionLowering options,
46 PatternBenefit benefit = 1)
48 useInnerDimsForReduction(
49 options == vector::VectorMultiReductionLowering::InnerReduction) {}
50
51 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
52 PatternRewriter &rewriter) const override {
53 // Vector mask setup.
54 OpBuilder::InsertionGuard guard(rewriter);
55 auto maskableOp =
56 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
57 Operation *rootOp;
58 if (maskableOp.isMasked()) {
59 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
60 rootOp = maskableOp.getMaskingOp();
61 } else {
62 rootOp = multiReductionOp;
63 }
64
65 auto src = multiReductionOp.getSource();
66 auto loc = multiReductionOp.getLoc();
67 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
68
69 // Separate reduction and parallel dims
70 ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
71 llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
72 reductionDims.end());
73 int64_t reductionSize = reductionDims.size();
74 SmallVector<int64_t, 4> parallelDims;
75 for (int64_t i = 0; i < srcRank; ++i)
76 if (!reductionDimsSet.contains(i))
77 parallelDims.push_back(i);
78
79 // Add transpose only if inner-most/outer-most dimensions are not parallel
80 // and there are parallel dims.
81 if (parallelDims.empty())
82 return failure();
83 if (useInnerDimsForReduction &&
84 (parallelDims ==
85 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
86 return failure();
87
88 if (!useInnerDimsForReduction &&
89 (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
90 reductionDims.size(),
91 parallelDims.size() + reductionDims.size()))))
92 return failure();
93
95 if (useInnerDimsForReduction) {
96 indices.append(parallelDims.begin(), parallelDims.end());
97 indices.append(reductionDims.begin(), reductionDims.end());
98 } else {
99 indices.append(reductionDims.begin(), reductionDims.end());
100 indices.append(parallelDims.begin(), parallelDims.end());
101 }
102
103 // If masked, transpose the original mask.
104 Value transposedMask;
105 if (maskableOp.isMasked()) {
106 transposedMask = vector::TransposeOp::create(
107 rewriter, loc, maskableOp.getMaskingOp().getMask(), indices);
108 }
109
110 // Transpose reduction source.
111 auto transposeOp = vector::TransposeOp::create(rewriter, loc, src, indices);
112 SmallVector<bool> reductionMask(srcRank, false);
113 for (int i = 0; i < reductionSize; ++i) {
114 if (useInnerDimsForReduction)
115 reductionMask[srcRank - i - 1] = true;
116 else
117 reductionMask[i] = true;
118 }
119
120 Operation *newMultiRedOp = vector::MultiDimReductionOp::create(
121 rewriter, multiReductionOp.getLoc(), transposeOp.getResult(),
122 multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
123 newMultiRedOp =
124 mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask);
125
126 rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0));
127 return success();
128 }
129
130private:
131 const bool useInnerDimsForReduction;
132};
133
134/// Flattens vector.multi_reduction to 2D
135///
136/// Given all reduction dimensions are either inner most or outer most,
137/// flattens all reduction and parallel dimensions so that there are only 2Ds.
138///
139/// BEFORE
140/// vector.multi_reduction <add>, %vec, %acc [2, 3] : vector<2x3x4x5xi32> to
141/// vector<2x3xi32>
142/// AFTER
143/// %vec_sc = vector.shape_cast %vec
144/// %acc_sc = vector.shape_cast %acc
145/// %res = vector.multi_reduction <add>, %vec_sc, %acc_cs [1] :
146/// vector<6x20xi32> to vector<6xi32> %res_sc = vector.shape_cast %res
147class FlattenMultiReduction
148 : public OpRewritePattern<vector::MultiDimReductionOp> {
149public:
150 using Base::Base;
151
152 explicit FlattenMultiReduction(MLIRContext *context,
153 vector::VectorMultiReductionLowering options,
154 PatternBenefit benefit = 1)
156 useInnerDimsForReduction(
157 options == vector::VectorMultiReductionLowering::InnerReduction) {}
158
159 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
160 PatternRewriter &rewriter) const override {
161 // Vector mask setup.
162 OpBuilder::InsertionGuard guard(rewriter);
163 auto maskableOp =
164 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
165 Operation *rootOp;
166 if (maskableOp.isMasked()) {
167 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
168 rootOp = maskableOp.getMaskingOp();
169 } else {
170 rootOp = multiReductionOp;
171 }
172
173 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
174 auto srcShape = multiReductionOp.getSourceVectorType().getShape();
175 auto srcScalableDims =
176 multiReductionOp.getSourceVectorType().getScalableDims();
177 auto loc = multiReductionOp.getLoc();
178
179 // If rank less than 2, nothing to do.
180 if (srcRank < 2)
181 return failure();
182
183 // Allow only 1 scalable dimensions. Otherwise we could end-up with e.g.
184 // `vscale * vscale` that's currently not modelled.
185 if (llvm::count(srcScalableDims, true) > 1)
186 return failure();
187
188 // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
189 SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
190 if (srcRank == 2 && reductionMask.front() != reductionMask.back())
191 return failure();
192
193 // 1. Separate reduction and parallel dims.
194 SmallVector<int64_t, 4> parallelDims, parallelShapes;
195 SmallVector<bool, 4> parallelScalableDims;
196 SmallVector<int64_t, 4> reductionDims, reductionShapes;
197 bool isReductionDimScalable = false;
198 for (const auto &it : llvm::enumerate(reductionMask)) {
199 int64_t i = it.index();
200 bool isReduction = it.value();
201 if (isReduction) {
202 reductionDims.push_back(i);
203 reductionShapes.push_back(srcShape[i]);
204 isReductionDimScalable |= srcScalableDims[i];
205 } else {
206 parallelDims.push_back(i);
207 parallelShapes.push_back(srcShape[i]);
208 parallelScalableDims.push_back(srcScalableDims[i]);
209 }
210 }
211
212 // 2. Compute flattened parallel and reduction sizes.
213 int flattenedParallelDim = 0;
214 int flattenedReductionDim = 0;
215 if (!parallelShapes.empty()) {
216 flattenedParallelDim = 1;
217 for (auto d : parallelShapes)
218 flattenedParallelDim *= d;
219 }
220 if (!reductionShapes.empty()) {
221 flattenedReductionDim = 1;
222 for (auto d : reductionShapes)
223 flattenedReductionDim *= d;
224 }
225 // We must at least have some parallel or some reduction.
226 assert((flattenedParallelDim || flattenedReductionDim) &&
227 "expected at least one parallel or reduction dim");
228
229 // 3. Fail if reduction/parallel dims are not contiguous.
230 // Check parallelDims are exactly [0 .. size).
231 int64_t counter = 0;
232 if (useInnerDimsForReduction &&
233 llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
234 return failure();
235 // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
236 counter = reductionDims.size();
237 if (!useInnerDimsForReduction &&
238 llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
239 return failure();
240
241 // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
242 // a single parallel (resp. reduction) dim.
244 SmallVector<bool, 2> scalableDims;
246 bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
247 if (flattenedParallelDim) {
248 mask.push_back(false);
249 vectorShape.push_back(flattenedParallelDim);
250 scalableDims.push_back(isParallelDimScalable);
251 }
252 if (flattenedReductionDim) {
253 mask.push_back(true);
254 vectorShape.push_back(flattenedReductionDim);
255 scalableDims.push_back(isReductionDimScalable);
256 }
257 if (!useInnerDimsForReduction && vectorShape.size() == 2) {
258 std::swap(mask.front(), mask.back());
259 std::swap(vectorShape.front(), vectorShape.back());
260 std::swap(scalableDims.front(), scalableDims.back());
261 }
262
263 Value newVectorMask;
264 if (maskableOp.isMasked()) {
265 Value vectorMask = maskableOp.getMaskingOp().getMask();
266 auto maskCastedType = VectorType::get(
268 llvm::cast<VectorType>(vectorMask.getType()).getElementType());
269 newVectorMask = vector::ShapeCastOp::create(rewriter, loc, maskCastedType,
270 vectorMask);
271 }
272
273 auto castedType = VectorType::get(
274 vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
275 scalableDims);
276 Value cast = vector::ShapeCastOp::create(rewriter, loc, castedType,
277 multiReductionOp.getSource());
278
279 Value acc = multiReductionOp.getAcc();
280 if (flattenedParallelDim) {
281 auto accType = VectorType::get(
282 {flattenedParallelDim},
283 multiReductionOp.getSourceVectorType().getElementType(),
284 /*scalableDims=*/{isParallelDimScalable});
285 acc = vector::ShapeCastOp::create(rewriter, loc, accType, acc);
286 }
287 // 6. Creates the flattened form of vector.multi_reduction with inner/outer
288 // most dim as reduction.
289 Operation *newMultiDimRedOp = vector::MultiDimReductionOp::create(
290 rewriter, loc, cast, acc, mask, multiReductionOp.getKind());
291 newMultiDimRedOp =
292 mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask);
293
294 // 7. If there are no parallel shapes, the result is a scalar.
295 // TODO: support 0-d vectors when available.
296 if (parallelShapes.empty()) {
297 rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0));
298 return success();
299 }
300
301 // 8. Shape cast the flattened result back to the original n-D parallel
302 // shape.
303 VectorType outputCastedType = VectorType::get(
304 parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
305 parallelScalableDims);
306 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
307 rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
308 return success();
309 }
310
311private:
312 const bool useInnerDimsForReduction;
313};
314
315/// Lowers 2D vector.multi_reduction to a squence of Arith Ops
316///
317/// The reduction dimension must be the outer-most dimension.
318///
319/// BEFORE:
320///
321/// %1 = vector.multi_reduction <mul>, %src, %acc [0] : vector<4x2xf32> to
322/// vector<2xf32>
323///
324/// AFTER:
325///
326/// // Prod 1.
327/// %vec_0 = vector.extract %src[0] : vector<2xf32> from vector<4x2xf32>
328/// %mul_0 = arith.mulf %vec_0, %acc : vector<2xf32>
329///
330/// // Prod 2.
331/// %vec_1 = vector.extract %src[1] : vector<2xf32> from vector<4x2xf32>
332/// %mul_2 = arith.mulf %vec_1, %mul_0 : vector<2xf32>
333///
334/// // Prod 3.
335/// %vec_3 = vector.extract %src[2] : vector<2xf32> from vector<4x2xf32>
336/// %mul_3 = arith.mulf %vec_3, %mul_2 : vector<2xf32>
337///
338/// // Prod 4.
339/// %vec_4 = vector.extract %src[3] : vector<2xf32> from vector<4x2xf32>
340/// %res = arith.mulf %vec_4, %mul_3 : vector<2xf32>
341struct TwoDimMultiReductionToElementWise
342 : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
343 using MaskableOpRewritePattern::MaskableOpRewritePattern;
344
345 FailureOr<Value>
346 matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
347 vector::MaskingOpInterface maskingOp,
348 PatternRewriter &rewriter) const override {
349 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
350 // Rank-2 ["parallel", "reduce"] or bail.
351 if (srcRank != 2)
352 return failure();
353
354 if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
355 return failure();
356
357 Value mask = maskingOp ? maskingOp.getMask() : Value();
358
359 auto loc = multiReductionOp.getLoc();
360 Value source = multiReductionOp.getSource();
361 ArrayRef<int64_t> srcShape =
362 multiReductionOp.getSourceVectorType().getShape();
363 int outerDim = srcShape[0];
364
365 Value result = multiReductionOp.getAcc();
366 for (int64_t i = 0; i < outerDim; i++) {
367 auto v = vector::ExtractOp::create(rewriter, loc, source, i);
368 Value m = mask ? Value(vector::ExtractOp::create(rewriter, loc, mask, i))
369 : nullptr;
370 result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(), v,
371 result, /*fastmath=*/nullptr, m);
372 }
373
374 return result;
375 }
376};
377
378/// Lowers 2D vector.multi_reduction to a sequence of vector.reduction Ops.
379///
380/// The reduction dimension must be the inner-most dimension.
381///
382/// BEFORE:
383/// vector.multi_reduction <mul>, %src, %acc [1] : vector<2x4xf32> to
384/// vector<2xf32>
385///
386/// AFTER:
387/// // 1st reduction
388/// %v_0 = vector.extract %src[0] : vector<4xf32> from vector<2x4xf32>
389/// %a_0 = vector.extract %acc[0] : f32 from vector<2xf32>
390/// %red_1 = vector.reduction <mul>, %v_0, %a_1 : vector<4xf32> into f32
391/// %res_tmp = vector.insert %red_1, %res [0] : f32 into vector<2xf32>
392///
393/// // 2nd reduction
394/// %v_1 = vector.extract %src[1] : vector<4xf32> from vector<2x4xf32>
395/// %a_1 = vector.extract %acc[1] : f32 from vector<2xf32>
396/// %red_2 = vector.reduction <mul>, %v_1, %a_1 : vector<4xf32> into f32
397/// %res_final = vector.insert %red_2, %res_tmp [1] : f32 into vector<2xf32>
398struct TwoDimMultiReductionToReduction
399 : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
400 using MaskableOpRewritePattern::MaskableOpRewritePattern;
401
402 FailureOr<Value>
403 matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
404 vector::MaskingOpInterface maskingOp,
405 PatternRewriter &rewriter) const override {
406 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
407 // Rank-2 ["reduce", "parallel"] or bail.
408 if (srcRank != 2)
409 return failure();
410
411 if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
412 return failure();
413
414 Value mask = maskingOp ? maskingOp.getMask() : nullptr;
415
416 auto loc = multiReductionOp.getLoc();
417 Value source = multiReductionOp.getSource();
418 Value acc = multiReductionOp.getAcc();
419 int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
420
421 Value result = arith::ConstantOp::create(
422 rewriter, loc, multiReductionOp.getDestType(),
423 rewriter.getZeroAttr(multiReductionOp.getDestType()));
424
425 SmallVector<Value> vectors(outerDim);
426 for (int64_t i = 0; i < outerDim; ++i) {
427 Value v = vector::ExtractOp::create(rewriter, loc, source, i);
428 Value a = vector::ExtractOp::create(rewriter, loc, acc, i);
429
430 Operation *reductionOp = vector::ReductionOp::create(
431 rewriter, loc, multiReductionOp.getKind(), v, a);
432
433 if (mask) {
434 Value m = vector::ExtractOp::create(rewriter, loc, mask, i);
435 reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, m);
436 }
437
438 result = vector::InsertOp::create(rewriter, loc,
439 reductionOp->getResult(0), result, i);
440 }
441
442 return result;
443 }
444};
445
446/// Converts 1D vector.multi_reduction directly to vector.reduction.
447///
448/// Example:
449/// ```mlir
450/// // Before
451/// %r = vector.multi_reduction <add>, %v, %acc [0] : vector<Nxf32> to f32
452///
453/// // After
454/// %r = vector.reduction <add>, %v, %acc : vector<Nxf32> into f32
455/// ```
456struct OneDimMultiReductionToReduction
457 : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
458 using MaskableOpRewritePattern::MaskableOpRewritePattern;
459
460 FailureOr<Value>
461 matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
462 vector::MaskingOpInterface maskingOp,
463 PatternRewriter &rewriter) const override {
464 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
465 if (srcRank != 1)
466 return failure();
467
468 if (!multiReductionOp.isReducedDim(0))
469 return failure();
470
471 auto loc = multiReductionOp.getLoc();
472 Value mask = maskingOp ? maskingOp.getMask() : Value();
473
474 Operation *reductionOp = vector::ReductionOp::create(
475 rewriter, loc, multiReductionOp.getKind(), multiReductionOp.getSource(),
476 multiReductionOp.getAcc());
477
478 if (mask)
479 reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
480
481 return reductionOp->getResult(0);
482 }
483};
484
485struct LowerVectorMultiReductionPass
486 : public vector::impl::LowerVectorMultiReductionBase<
487 LowerVectorMultiReductionPass> {
488 LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
489 this->loweringStrategy = option;
490 }
491
492 void runOnOperation() override {
493 Operation *op = getOperation();
494 MLIRContext *context = op->getContext();
495
496 RewritePatternSet patterns(context);
498 patterns, this->loweringStrategy);
499 if (failed(applyPatternsGreedily(op, std::move(patterns))))
500 signalPassFailure();
501
502 RewritePatternSet flatteningPatterns(context);
504 flatteningPatterns, this->loweringStrategy);
505 if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns))))
506 signalPassFailure();
507
508 RewritePatternSet unrollingPatterns(context);
510 unrollingPatterns, this->loweringStrategy);
511 if (failed(applyPatternsGreedily(op, std::move(unrollingPatterns))))
512 signalPassFailure();
513 }
514
515 void getDependentDialects(DialectRegistry &registry) const override {
516 registry.insert<vector::VectorDialect>();
517 }
518};
519
520} // namespace
521
523 RewritePatternSet &patterns, VectorMultiReductionLowering options,
524 PatternBenefit benefit) {
525 patterns.add<InnerOuterDimReductionConversion>(patterns.getContext(), options,
526 benefit);
527}
528
530 RewritePatternSet &patterns, VectorMultiReductionLowering options,
531 PatternBenefit benefit) {
532 patterns.add<FlattenMultiReduction>(patterns.getContext(), options, benefit);
533}
534
536 RewritePatternSet &patterns, VectorMultiReductionLowering options,
537 PatternBenefit benefit) {
538 patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
539 if (options == VectorMultiReductionLowering ::InnerReduction)
540 patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
541 benefit);
542 else
543 patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
544 benefit);
545}
546
548 vector::VectorMultiReductionLowering option) {
549 return std::make_unique<LowerVectorMultiReductionPass>(option);
550}
return success()
static Type getElementType(Type type)
Determine the element type of type.
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< VectorShape > vectorShape(Type type)
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:415
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
std::unique_ptr< Pass > createLowerVectorMultiReductionPass(VectorMultiReductionLowering option=VectorMultiReductionLowering::InnerParallel)
Creates an instance of the vector.multi_reduction lowering pass.
void populateVectorMultiReductionUnrollingPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMultiReductionFlatteningPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMultiReductionReorderPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.