24#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
25#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
29#define DEBUG_TYPE "vector-multi-reduction"
// Rewrite pattern over vector::MultiDimReductionOp that transposes the source
// so that the reduction dimensions are grouped together, either innermost or
// outermost depending on the lowering option.
// NOTE(review): this listing is an excerpt; several original lines (base
// class, early returns, some declarations) are not visible here.
39class InnerOuterDimReductionConversion
// The constructor records whether the requested lowering is InnerReduction.
44 explicit InnerOuterDimReductionConversion(
48 useInnerDimsForReduction(
49 options == vector::VectorMultiReductionLowering::InnerReduction) {}
51 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
56 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
// The op that ends up being replaced is the enclosing masking op when the
// reduction is masked, otherwise the reduction op itself.
58 if (maskableOp.isMasked()) {
60 rootOp = maskableOp.getMaskingOp();
62 rootOp = multiReductionOp;
65 auto src = multiReductionOp.getSource();
66 auto loc = multiReductionOp.
getLoc();
67 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
71 llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
73 int64_t reductionSize = reductionDims.size();
// Every source dimension index that is not a reduction dimension is a
// parallel dimension.
75 for (
int64_t i = 0; i < srcRank; ++i)
76 if (!reductionDimsSet.contains(i))
77 parallelDims.push_back(i);
// The following three checks compare parallelDims against already-ordered
// index sequences; the statements they guard are not visible in this excerpt
// (presumably early exits when no transpose is needed -- TODO confirm).
81 if (parallelDims.empty())
83 if (useInnerDimsForReduction &&
85 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
88 if (!useInnerDimsForReduction &&
89 (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
91 parallelDims.size() + reductionDims.size()))))
// Build the transpose permutation: parallel dims first, then reduction dims
// when reducing over inner dims; the second pair of appends uses the opposite
// order (presumably the else branch, whose `else` line is not visible).
95 if (useInnerDimsForReduction) {
96 indices.append(parallelDims.begin(), parallelDims.end());
97 indices.append(reductionDims.begin(), reductionDims.end());
99 indices.append(reductionDims.begin(), reductionDims.end());
100 indices.append(parallelDims.begin(), parallelDims.end());
// A masked reduction gets its mask transposed with the same permutation.
105 if (maskableOp.isMasked()) {
106 transposedMask = vector::TransposeOp::create(
107 rewriter, loc, maskableOp.getMaskingOp().getMask(),
indices);
// Transpose the source vector with the permutation built above.
111 auto transposeOp = vector::TransposeOp::create(rewriter, loc, src,
indices);
// Mark the reduced positions of the transposed vector: the last
// `reductionSize` positions for inner reduction, otherwise the first ones.
113 for (
int i = 0; i < reductionSize; ++i) {
114 if (useInnerDimsForReduction)
115 reductionMask[srcRank - i - 1] =
true;
117 reductionMask[i] =
true;
// Re-create the reduction on the transposed value, keeping the original
// accumulator and combining kind.
120 Operation *newMultiRedOp = vector::MultiDimReductionOp::create(
121 rewriter, multiReductionOp.getLoc(), transposeOp.getResult(),
122 multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
// True when the pattern was constructed with the InnerReduction option.
131 const bool useInnerDimsForReduction;
// Rewrite pattern over vector::MultiDimReductionOp that flattens the parallel
// dimensions into one dimension and the reduction dimensions into another via
// shape casts, then re-creates the reduction on the flattened vector.
// NOTE(review): this listing is an excerpt; several original lines (base
// class, early returns, some declarations) are not visible here.
136class ReduceMultiDimReductionRank
// The constructor records whether the requested lowering is InnerReduction.
141 explicit ReduceMultiDimReductionRank(
145 useInnerDimsForReduction(
146 options == vector::VectorMultiReductionLowering::InnerReduction) {}
148 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
153 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
// The op that ends up being replaced is the enclosing masking op when the
// reduction is masked, otherwise the reduction op itself.
155 if (maskableOp.isMasked()) {
157 rootOp = maskableOp.getMaskingOp();
159 rootOp = multiReductionOp;
162 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
163 auto srcShape = multiReductionOp.getSourceVectorType().getShape();
164 auto srcScalableDims =
165 multiReductionOp.getSourceVectorType().getScalableDims();
166 auto loc = multiReductionOp.
getLoc();
// Tests whether more than one source dimension is scalable; the guarded
// statement is not visible here (presumably a bail-out -- TODO confirm).
174 if (llvm::count(srcScalableDims,
true) > 1)
// Tests for a rank-2 source whose first and last reduction-mask entries
// differ; the guarded statement is not visible here (presumably a bail-out
// because the op is already in 2-D form -- TODO confirm).
179 if (srcRank == 2 && reductionMask.front() != reductionMask.back())
186 bool isReductionDimScalable =
false;
// Split the source dimensions into reduction and parallel groups, recording
// each group's sizes and scalability. `i` is defined on a line not visible
// in this excerpt (presumably the enumerate index).
187 for (
const auto &it : llvm::enumerate(reductionMask)) {
189 bool isReduction = it.value();
191 reductionDims.push_back(i);
192 reductionShapes.push_back(srcShape[i]);
193 isReductionDimScalable |= srcScalableDims[i];
195 parallelDims.push_back(i);
196 parallelShapes.push_back(srcShape[i]);
197 parallelScalableDims.push_back(srcScalableDims[i]);
// Each flattened size is the product of its group's sizes, and stays 0 when
// the group is empty.
202 int flattenedParallelDim = 0;
203 int flattenedReductionDim = 0;
204 if (!parallelShapes.empty()) {
205 flattenedParallelDim = 1;
206 for (
auto d : parallelShapes)
207 flattenedParallelDim *= d;
209 if (!reductionShapes.empty()) {
210 flattenedReductionDim = 1;
211 for (
auto d : reductionShapes)
212 flattenedReductionDim *= d;
215 assert((flattenedParallelDim || flattenedReductionDim) &&
216 "expected at least one parallel or reduction dim");
// Check that the parallel dims form a consecutive run starting at `counter`.
// The initial value of `counter` for the inner-reduction case and the
// statements guarded by both checks are not visible in this excerpt.
221 if (useInnerDimsForReduction &&
222 llvm::any_of(parallelDims, [&](
int64_t i) {
return i != counter++; }))
// For outer reduction the parallel dims are expected to start right after
// the reduction dims.
225 counter = reductionDims.size();
226 if (!useInnerDimsForReduction &&
227 llvm::any_of(parallelDims, [&](
int64_t i) {
return i != counter++; }))
// Build the reduction mask and scalability flags of the flattened vector:
// the parallel entry (not reduced) first, then the reduction entry (reduced).
235 bool isParallelDimScalable = llvm::is_contained(parallelScalableDims,
true);
236 if (flattenedParallelDim) {
237 mask.push_back(
false);
239 scalableDims.push_back(isParallelDimScalable);
241 if (flattenedReductionDim) {
242 mask.push_back(
true);
244 scalableDims.push_back(isReductionDimScalable);
// For outer reduction with a 2-D flattened shape, swap the entries so the
// reduction dimension comes first.
246 if (!useInnerDimsForReduction &&
vectorShape.size() == 2) {
247 std::swap(mask.front(), mask.back());
249 std::swap(scalableDims.front(), scalableDims.back());
// A masked reduction gets its mask shape-cast as well, keeping the mask's
// element type.
253 if (maskableOp.isMasked()) {
254 Value vectorMask = maskableOp.getMaskingOp().getMask();
255 auto maskCastedType = VectorType::get(
257 llvm::cast<VectorType>(vectorMask.
getType()).getElementType());
258 newVectorMask = vector::ShapeCastOp::create(rewriter, loc, maskCastedType,
// Shape-cast the source to the flattened vector type.
262 auto castedType = VectorType::get(
263 vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
265 Value cast = vector::ShapeCastOp::create(rewriter, loc, castedType,
266 multiReductionOp.getSource());
// When there is a parallel dimension, the accumulator is shape-cast to a
// 1-D vector of the flattened parallel size.
268 Value acc = multiReductionOp.getAcc();
269 if (flattenedParallelDim) {
270 auto accType = VectorType::get(
271 {flattenedParallelDim},
273 {isParallelDimScalable});
274 acc = vector::ShapeCastOp::create(rewriter, loc, accType,
acc);
// Re-create the reduction on the flattened operands with the original
// combining kind.
278 Operation *newMultiDimRedOp = vector::MultiDimReductionOp::create(
279 rewriter, loc, cast,
acc, mask, multiReductionOp.getKind());
// With no parallel dims the handling is on lines not visible here; otherwise
// the result is converted back to a vector with the original parallel shape
// (the call taking `rootOp` below is only partially visible).
285 if (parallelShapes.empty()) {
291 VectorType outputCastedType = VectorType::get(
292 parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
293 parallelScalableDims);
295 rootOp, outputCastedType, newMultiDimRedOp->
getResult(0));
// True when the pattern was constructed with the InnerReduction option.
300 const bool useInnerDimsForReduction;
// Rewrite pattern over vector::MultiDimReductionOp that combines the slices
// along dimension 0 of the source one by one using makeArithReduction.
// NOTE(review): this listing is an excerpt; several original lines are not
// visible here.
305struct TwoDimMultiReductionToElementWise
309 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
311 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
// Tests whether dim 1 is reduced or dim 0 is not reduced; the guarded
// statement is not visible here (presumably a bail-out, so the pattern only
// handles reductions over dim 0 alone -- TODO confirm).
316 if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
319 auto loc = multiReductionOp.getLoc();
321 multiReductionOp.getSourceVectorType().getShape();
329 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
// For a masked reduction, the masking op becomes the root op and its mask is
// remembered; otherwise the reduction op itself is the root and `mask` stays
// null.
331 Value mask =
nullptr;
332 if (maskableOp.isMasked()) {
334 rootOp = maskableOp.getMaskingOp();
335 mask = maskableOp.getMaskingOp().getMask();
337 rootOp = multiReductionOp;
// For each index along dim 0: extract that slice of the source (and of the
// mask; the condition guarding the mask extraction is not visible) and
// combine it with the running `result`.
341 for (
int64_t i = 0; i < srcShape[0]; i++) {
342 auto operand = vector::ExtractOp::create(rewriter, loc,
343 multiReductionOp.getSource(), i);
344 Value extractMask =
nullptr;
346 extractMask = vector::ExtractOp::create(rewriter, loc, mask, i);
349 makeArithReduction(rewriter, loc, multiReductionOp.getKind(), operand,
350 result,
nullptr, extractMask);
// Rewrite pattern over vector::MultiDimReductionOp that emits one
// vector::ReductionOp per index along dimension 0 of the source and inserts
// each reduced value into the result.
// NOTE(review): this listing is an excerpt; several original lines are not
// visible here.
360struct TwoDimMultiReductionToReduction
364 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
365 PatternRewriter &rewriter)
const override {
366 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
// Tests whether dim 0 is reduced or dim 1 is not reduced; the guarded
// statement is not visible here (presumably a bail-out, so the pattern only
// handles reductions over dim 1 alone -- TODO confirm).
370 if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
374 OpBuilder::InsertionGuard guard(rewriter);
376 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
// The op that ends up being replaced is the enclosing masking op when the
// reduction is masked, otherwise the reduction op itself.
378 if (maskableOp.isMasked()) {
380 rootOp = maskableOp.getMaskingOp();
382 rootOp = multiReductionOp;
385 auto loc = multiReductionOp.
getLoc();
// Start from a zero constant of the destination type; the loop below fills
// it in one element at a time.
386 Value
result = arith::ConstantOp::create(
387 rewriter, loc, multiReductionOp.getDestType(),
388 rewriter.
getZeroAttr(multiReductionOp.getDestType()));
389 int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
// For each index along dim 0: extract the source slice and the matching
// accumulator element, then reduce them with the original combining kind.
391 for (
int i = 0; i < outerDim; ++i) {
392 auto v = vector::ExtractOp::create(
393 rewriter, loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
394 auto acc = vector::ExtractOp::create(
395 rewriter, loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
396 Operation *reductionOp = vector::ReductionOp::create(
397 rewriter, loc, multiReductionOp.getKind(), v, acc);
// For a masked reduction the matching mask slice is extracted too; how it is
// applied to `reductionOp` is on lines not visible in this excerpt.
400 if (maskableOp.isMasked()) {
401 Value mask = vector::ExtractOp::create(
402 rewriter, loc, maskableOp.getMaskingOp().getMask(),
403 ArrayRef<int64_t>{i});
// Update `result` with an insert op (its operands are not visible here).
407 result = vector::InsertOp::create(rewriter, loc,
// Rewrite pattern over vector::MultiDimReductionOp that re-expresses the
// reduction on a 2-D vector with a leading unit dimension, reducing only the
// trailing dimension.
// NOTE(review): this listing is an excerpt; several original lines (rank
// check, some declarations, final replacement) are not visible here.
421struct OneDimMultiReductionToTwoDim
425 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
426 PatternRewriter &rewriter)
const override {
427 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
433 OpBuilder::InsertionGuard guard(rewriter);
435 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
// For a masked reduction, the masking op becomes the root op and its mask is
// remembered; otherwise the reduction op itself is the root.
438 if (maskableOp.isMasked()) {
440 rootOp = maskableOp.getMaskingOp();
441 mask = maskableOp.getMaskingOp().getMask();
443 rootOp = multiReductionOp;
446 auto loc = multiReductionOp.
getLoc();
447 auto srcVectorType = multiReductionOp.getSourceVectorType();
448 auto srcShape = srcVectorType.getShape();
// 2-D source type: a fixed leading unit dim followed by the last source dim,
// which keeps its scalability.
449 auto castedType = VectorType::get(
450 ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),
451 ArrayRef<bool>{
false, srcVectorType.getScalableDims().back()});
// One-element vector type of the source element type (presumably the
// initialiser of `accType`, whose declaration line is not visible).
454 VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
455 assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
456 "multi_reduction with a single dimension expects a scalar result");
// Only the second (trailing) dimension of the 2-D vector is reduced.
460 SmallVector<bool, 2> reductionMask{
false,
true};
// The source is shape-cast to the 2-D type; the accumulator is broadcast to
// `accType`.
463 Value cast = vector::ShapeCastOp::create(rewriter, loc, castedType,
464 multiReductionOp.getSource());
465 Value castAcc = vector::BroadcastOp::create(rewriter, loc, accType,
466 multiReductionOp.getAcc());
// A masked reduction gets its mask broadcast to the matching 2-D mask type.
468 if (maskableOp.isMasked()) {
469 auto maskType = llvm::cast<VectorType>(mask.
getType());
470 auto castMaskType = VectorType::get(
471 ArrayRef<int64_t>{1, maskType.getShape().back()},
472 maskType.getElementType(),
473 ArrayRef<bool>{
false, maskType.getScalableDims().back()});
474 castMask = vector::BroadcastOp::create(rewriter, loc, castMaskType, mask);
// Re-create the reduction on the 2-D operands with the original combining
// kind.
477 Operation *newOp = vector::MultiDimReductionOp::create(
478 rewriter, loc, cast, castAcc, reductionMask,
479 multiReductionOp.getKind());
// Tail of a call using index 0 (presumably extracting the scalar result from
// the one-element vector; the start of the call is not visible).
483 ArrayRef<int64_t>{0});
// Pass wrapper around the vector.multi_reduction lowering patterns.
// NOTE(review): this listing is an excerpt; the base-class name, the
// definition of `context`, and the pattern application are on lines not
// visible here.
488struct LowerVectorMultiReductionPass
490 LowerVectorMultiReductionPass> {
// Stores the requested lowering strategy in the pass's `loweringStrategy`
// member.
491 LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
492 this->loweringStrategy = option;
// Builds a pattern set for the current operation; the call that receives
// `this->loweringStrategy` is only partially visible.
495 void runOnOperation()
override {
496 Operation *op = getOperation();
499 RewritePatternSet loweringPatterns(context);
501 this->loweringStrategy);
// Registers the vector dialect as a dependency of this pass. The parameter is
// a reference named `registry`; the extraction had mangled `&registry` into
// `®istry` through HTML-entity decoding.
507 void getDependentDialects(DialectRegistry &registry)
const override {
508 registry.
insert<vector::VectorDialect>();
// Body fragment of a pattern-population function whose signature is not
// visible in this excerpt (presumably
// populateVectorMultiReductionLoweringPatterns -- TODO confirm).
// Registers the transpose and rank-reduction patterns; the constructor
// arguments are on a line not visible here.
517 patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
// Tests for the InnerReduction strategy; the patterns added in each case are
// on lines not visible here.
520 if (
options == VectorMultiReductionLowering ::InnerReduction)
// Tail of a factory function whose name and return type are on a line not
// visible here (presumably createLowerVectorMultiReductionPass -- TODO
// confirm). Returns a newly allocated LowerVectorMultiReductionPass built
// with the given lowering option.
529 vector::VectorMultiReductionLowering option) {
530 return std::make_unique<LowerVectorMultiReductionPass>(option);
static Type getElementType(Type type)
Determine the element type of type.
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< VectorShape > vectorShape(Type type)
TypedAttr getZeroAttr(Type type)
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext()
Return the context this operation is associated with.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
LowerVectorMultiReductionBase Base
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
std::unique_ptr< Pass > createLowerVectorMultiReductionPass(VectorMultiReductionLowering option=VectorMultiReductionLowering::InnerParallel)
Creates an instance of the vector.multi_reduction lowering pass.
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector....
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...