//===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.multi_reduction' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace vector {
#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
} // namespace vector
} // namespace mlir

#define DEBUG_TYPE "vector-multi-reduction"

using namespace mlir;

namespace {
/// This file implements the following transformations as composable atomic
/// patterns.

/// Converts vector.multi_reduction into inner-most/outer-most reduction form
/// by using vector.transpose
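///
/// For example (an illustrative sketch assuming the InnerReduction strategy
/// and an unmasked op; value names, shapes and exact IR syntax are
/// illustrative and may differ between MLIR versions):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [0]
///          : vector<4x8xf32> to vector<8xf32>
///
/// is rewritten, roughly, into:
///
///   %t = vector.transpose %src, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
///   %0 = vector.multi_reduction <add>, %t, %acc [1]
///          : vector<8x4xf32> to vector<8xf32>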
class InnerOuterDimReductionConversion
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using Base::Base;

  explicit InnerOuterDimReductionConversion(
      MLIRContext *context, vector::VectorMultiReductionLowering options,
      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto src = multiReductionOp.getSource();
    auto loc = multiReductionOp.getLoc();
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();

    // Separate reduction and parallel dims
    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
                                                  reductionDims.end());
    int64_t reductionSize = reductionDims.size();
    SmallVector<int64_t, 4> parallelDims;
    for (int64_t i = 0; i < srcRank; ++i)
      if (!reductionDimsSet.contains(i))
        parallelDims.push_back(i);

    // Add transpose only if inner-most/outer-most dimensions are not parallel
    // and there are parallel dims.
    if (parallelDims.empty())
      return failure();
    if (useInnerDimsForReduction &&
        (parallelDims ==
         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
      return failure();

    if (!useInnerDimsForReduction &&
        (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
                             reductionDims.size(),
                             parallelDims.size() + reductionDims.size()))))
      return failure();

    SmallVector<int64_t> indices;
    if (useInnerDimsForReduction) {
      indices.append(parallelDims.begin(), parallelDims.end());
      indices.append(reductionDims.begin(), reductionDims.end());
    } else {
      indices.append(reductionDims.begin(), reductionDims.end());
      indices.append(parallelDims.begin(), parallelDims.end());
    }

    // If masked, transpose the original mask.
    Value transposedMask;
    if (maskableOp.isMasked()) {
      transposedMask = vector::TransposeOp::create(
          rewriter, loc, maskableOp.getMaskingOp().getMask(), indices);
    }

    // Transpose reduction source.
    auto transposeOp = vector::TransposeOp::create(rewriter, loc, src, indices);
    SmallVector<bool> reductionMask(srcRank, false);
    for (int i = 0; i < reductionSize; ++i) {
      if (useInnerDimsForReduction)
        reductionMask[srcRank - i - 1] = true;
      else
        reductionMask[i] = true;
    }

    Operation *newMultiRedOp = vector::MultiDimReductionOp::create(
        rewriter, multiReductionOp.getLoc(), transposeOp.getResult(),
        multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
    newMultiRedOp =
        mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask);

    rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0));
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
/// dimensions are either innermost or outermost.
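///
/// For example (an illustrative sketch assuming the InnerReduction strategy
/// and an unmasked op; value names, shapes and exact IR syntax are
/// illustrative and may differ between MLIR versions):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [2, 3]
///          : vector<2x3x4x5xf32> to vector<2x3xf32>
///
/// is rewritten, roughly, into:
///
///   %s = vector.shape_cast %src : vector<2x3x4x5xf32> to vector<6x20xf32>
///   %a = vector.shape_cast %acc : vector<2x3xf32> to vector<6xf32>
///   %r = vector.multi_reduction <add>, %s, %a [1]
///          : vector<6x20xf32> to vector<6xf32>
///   %0 = vector.shape_cast %r : vector<6xf32> to vector<2x3xf32>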
class ReduceMultiDimReductionRank
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using Base::Base;

  explicit ReduceMultiDimReductionRank(
      MLIRContext *context, vector::VectorMultiReductionLowering options,
      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    auto srcShape = multiReductionOp.getSourceVectorType().getShape();
    auto srcScalableDims =
        multiReductionOp.getSourceVectorType().getScalableDims();
    auto loc = multiReductionOp.getLoc();

    // If rank less than 2, nothing to do.
    if (srcRank < 2)
      return failure();

    // Allow only 1 scalable dimension. Otherwise we could end up with e.g.
    // `vscale * vscale` that's currently not modelled.
    if (llvm::count(srcScalableDims, true) > 1)
      return failure();

    // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
    SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
    if (srcRank == 2 && reductionMask.front() != reductionMask.back())
      return failure();

    // 1. Separate reduction and parallel dims.
    SmallVector<int64_t, 4> parallelDims, parallelShapes;
    SmallVector<bool, 4> parallelScalableDims;
    SmallVector<int64_t, 4> reductionDims, reductionShapes;
    bool isReductionDimScalable = false;
    for (const auto &it : llvm::enumerate(reductionMask)) {
      int64_t i = it.index();
      bool isReduction = it.value();
      if (isReduction) {
        reductionDims.push_back(i);
        reductionShapes.push_back(srcShape[i]);
        isReductionDimScalable |= srcScalableDims[i];
      } else {
        parallelDims.push_back(i);
        parallelShapes.push_back(srcShape[i]);
        parallelScalableDims.push_back(srcScalableDims[i]);
      }
    }

    // 2. Compute flattened parallel and reduction sizes.
    int flattenedParallelDim = 0;
    int flattenedReductionDim = 0;
    if (!parallelShapes.empty()) {
      flattenedParallelDim = 1;
      for (auto d : parallelShapes)
        flattenedParallelDim *= d;
    }
    if (!reductionShapes.empty()) {
      flattenedReductionDim = 1;
      for (auto d : reductionShapes)
        flattenedReductionDim *= d;
    }
    // We must at least have some parallel or some reduction.
    assert((flattenedParallelDim || flattenedReductionDim) &&
           "expected at least one parallel or reduction dim");

    // 3. Fail if reduction/parallel dims are not contiguous.
    // Check parallelDims are exactly [0 .. size).
    int64_t counter = 0;
    if (useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();
    // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
    counter = reductionDims.size();
    if (!useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();

    // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
    // a single parallel (resp. reduction) dim.
    SmallVector<bool, 2> mask;
    SmallVector<bool, 2> scalableDims;
    SmallVector<int64_t, 2> vectorShape;
    bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
    if (flattenedParallelDim) {
      mask.push_back(false);
      vectorShape.push_back(flattenedParallelDim);
      scalableDims.push_back(isParallelDimScalable);
    }
    if (flattenedReductionDim) {
      mask.push_back(true);
      vectorShape.push_back(flattenedReductionDim);
      scalableDims.push_back(isReductionDimScalable);
    }
    if (!useInnerDimsForReduction && vectorShape.size() == 2) {
      std::swap(mask.front(), mask.back());
      std::swap(vectorShape.front(), vectorShape.back());
      std::swap(scalableDims.front(), scalableDims.back());
    }

    Value newVectorMask;
    if (maskableOp.isMasked()) {
      Value vectorMask = maskableOp.getMaskingOp().getMask();
      auto maskCastedType = VectorType::get(
          vectorShape,
          llvm::cast<VectorType>(vectorMask.getType()).getElementType());
      newVectorMask = vector::ShapeCastOp::create(rewriter, loc, maskCastedType,
                                                  vectorMask);
    }

    auto castedType = VectorType::get(
        vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
        scalableDims);
    Value cast = vector::ShapeCastOp::create(rewriter, loc, castedType,
                                             multiReductionOp.getSource());

    Value acc = multiReductionOp.getAcc();
    if (flattenedParallelDim) {
      auto accType = VectorType::get(
          {flattenedParallelDim},
          multiReductionOp.getSourceVectorType().getElementType(),
          /*scalableDims=*/{isParallelDimScalable});
      acc = vector::ShapeCastOp::create(rewriter, loc, accType, acc);
    }
    // 6. Creates the flattened form of vector.multi_reduction with inner/outer
    // most dim as reduction.
    Operation *newMultiDimRedOp = vector::MultiDimReductionOp::create(
        rewriter, loc, cast, acc, mask, multiReductionOp.getKind());
    newMultiDimRedOp =
        mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask);

    // 7. If there are no parallel shapes, the result is a scalar.
    // TODO: support 0-d vectors when available.
    if (parallelShapes.empty()) {
      rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0));
      return success();
    }

    // 8. Creates shape cast for the output n-D -> 2-D.
    VectorType outputCastedType = VectorType::get(
        parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
        parallelScalableDims);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Unrolls vector.multi_reduction with an outermost reduction dimension
/// and combines the partial results elementwise.
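///
/// For example (an illustrative sketch assuming an unmasked op; value names
/// and exact IR syntax are illustrative and may differ between MLIR versions):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [0]
///          : vector<4x8xf32> to vector<8xf32>
///
/// is rewritten, roughly, into:
///
///   %r0 = vector.extract %src[0] : vector<8xf32> from vector<4x8xf32>
///   %a0 = arith.addf %r0, %acc : vector<8xf32>
///   %r1 = vector.extract %src[1] : vector<8xf32> from vector<4x8xf32>
///   %a1 = arith.addf %r1, %a0 : vector<8xf32>
///   ... and likewise for rows 2 and 3, accumulating into the final result.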
struct TwoDimMultiReductionToElementWise
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-2 ["reduce", "parallel"] or bail.
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
      return failure();

    auto loc = multiReductionOp.getLoc();
    ArrayRef<int64_t> srcShape =
        multiReductionOp.getSourceVectorType().getShape();

    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
    if (!elementType.isIntOrIndexOrFloat())
      return failure();

    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    Value mask = nullptr;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = multiReductionOp;
    }

    Value result = multiReductionOp.getAcc();
    for (int64_t i = 0; i < srcShape[0]; i++) {
      auto operand = vector::ExtractOp::create(rewriter, loc,
                                               multiReductionOp.getSource(), i);
      Value extractMask = nullptr;
      if (mask) {
        extractMask = vector::ExtractOp::create(rewriter, loc, mask, i);
      }
      result =
          makeArithReduction(rewriter, loc, multiReductionOp.getKind(), operand,
                             result, /*fastmath=*/nullptr, extractMask);
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

/// Converts 2d vector.multi_reduction with innermost reduction dimension into
/// a sequence of vector.reduction ops.
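///
/// For example (an illustrative sketch assuming an unmasked op; value names
/// and exact IR syntax are illustrative and may differ between MLIR versions):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [1]
///          : vector<2x8xf32> to vector<2xf32>
///
/// is rewritten, roughly, into (%cst being a zero-initialized vector<2xf32>):
///
///   %v0 = vector.extract %src[0] : vector<8xf32> from vector<2x8xf32>
///   %a0 = vector.extract %acc[0] : f32 from vector<2xf32>
///   %r0 = vector.reduction <add>, %v0, %a0 : vector<8xf32> into f32
///   %i0 = vector.insert %r0, %cst[0] : f32 into vector<2xf32>
///   ... and likewise for row 1, inserting into %i0.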
struct TwoDimMultiReductionToReduction
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
      return failure();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto loc = multiReductionOp.getLoc();
    Value result = arith::ConstantOp::create(
        rewriter, loc, multiReductionOp.getDestType(),
        rewriter.getZeroAttr(multiReductionOp.getDestType()));
    int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];

    for (int i = 0; i < outerDim; ++i) {
      auto v = vector::ExtractOp::create(
          rewriter, loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
      auto acc = vector::ExtractOp::create(
          rewriter, loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
      Operation *reductionOp = vector::ReductionOp::create(
          rewriter, loc, multiReductionOp.getKind(), v, acc);

      // If masked, slice the mask and mask the new reduction operation.
      if (maskableOp.isMasked()) {
        Value mask = vector::ExtractOp::create(
            rewriter, loc, maskableOp.getMaskingOp().getMask(),
            ArrayRef<int64_t>{i});
        reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
      }

      result = vector::InsertOp::create(rewriter, loc,
                                        reductionOp->getResult(0), result, i);
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

/// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
/// form with both a single parallel and reduction dimension.
/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
/// The case with a single parallel dimension is a noop and folds away
/// separately.
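///
/// For example (an illustrative sketch assuming an unmasked op; value names
/// and exact IR syntax are illustrative and may differ between MLIR versions):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [0] : vector<8xf32> to f32
///
/// is rewritten, roughly, into:
///
///   %c = vector.shape_cast %src : vector<8xf32> to vector<1x8xf32>
///   %a = vector.broadcast %acc : f32 to vector<1xf32>
///   %r = vector.multi_reduction <add>, %c, %a [1]
///          : vector<1x8xf32> to vector<1xf32>
///   %0 = vector.extract %r[0] : f32 from vector<1xf32>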
struct OneDimMultiReductionToTwoDim
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-1 or bail.
    if (srcRank != 1)
      return failure();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = multiReductionOp;
    }

    auto loc = multiReductionOp.getLoc();
    auto srcVectorType = multiReductionOp.getSourceVectorType();
    auto srcShape = srcVectorType.getShape();
    auto castedType = VectorType::get(
        ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),
        ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});

    auto accType =
        VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
    assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
           "multi_reduction with a single dimension expects a scalar result");

    // If the unique dim is reduced and we insert a parallel in front, we need a
    // {false, true} mask.
    SmallVector<bool, 2> reductionMask{false, true};

    /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
    Value cast = vector::ShapeCastOp::create(rewriter, loc, castedType,
                                             multiReductionOp.getSource());
    Value castAcc = vector::BroadcastOp::create(rewriter, loc, accType,
                                                multiReductionOp.getAcc());
    Value castMask;
    if (maskableOp.isMasked()) {
      auto maskType = llvm::cast<VectorType>(mask.getType());
      auto castMaskType = VectorType::get(
          ArrayRef<int64_t>{1, maskType.getShape().back()},
          maskType.getElementType(),
          ArrayRef<bool>{false, maskType.getScalableDims().back()});
      castMask = vector::BroadcastOp::create(rewriter, loc, castMaskType, mask);
    }

    Operation *newOp = vector::MultiDimReductionOp::create(
        rewriter, loc, cast, castAcc, reductionMask,
        multiReductionOp.getKind());
    newOp = vector::maskOperation(rewriter, newOp, castMask);

    rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
                                                   ArrayRef<int64_t>{0});
    return success();
  }
};

struct LowerVectorMultiReductionPass
    : public vector::impl::LowerVectorMultiReductionBase<
          LowerVectorMultiReductionPass> {
  LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
    this->loweringStrategy = option;
  }

  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();

    RewritePatternSet loweringPatterns(context);
    vector::populateVectorMultiReductionLoweringPatterns(
        loweringPatterns, this->loweringStrategy);

    if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
      signalPassFailure();
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }
};

} // namespace

void mlir::vector::populateVectorMultiReductionLoweringPatterns(
    RewritePatternSet &patterns, VectorMultiReductionLowering options,
    PatternBenefit benefit) {
  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
      patterns.getContext(), options, benefit);
  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
  if (options == VectorMultiReductionLowering::InnerReduction)
    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
                                                  benefit);
  else
    patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
                                                    benefit);
}

std::unique_ptr<Pass> mlir::vector::createLowerVectorMultiReductionPass(
    vector::VectorMultiReductionLowering option) {
  return std::make_unique<LowerVectorMultiReductionPass>(option);
}