24#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
25#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
29#define DEBUG_TYPE "vector-multi-reduction"
39class InnerOuterDimReductionConversion
44 explicit InnerOuterDimReductionConversion(
48 useInnerDimsForReduction(
49 options == vector::VectorMultiReductionLowering::InnerReduction) {}
51 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
56 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
58 if (maskableOp.isMasked()) {
60 rootOp = maskableOp.getMaskingOp();
62 rootOp = multiReductionOp;
65 auto src = multiReductionOp.getSource();
66 auto loc = multiReductionOp.
getLoc();
67 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
71 llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
73 int64_t reductionSize = reductionDims.size();
75 for (
int64_t i = 0; i < srcRank; ++i)
76 if (!reductionDimsSet.contains(i))
77 parallelDims.push_back(i);
81 if (parallelDims.empty())
83 if (useInnerDimsForReduction &&
85 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
88 if (!useInnerDimsForReduction &&
89 (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
91 parallelDims.size() + reductionDims.size()))))
95 if (useInnerDimsForReduction) {
96 indices.append(parallelDims.begin(), parallelDims.end());
97 indices.append(reductionDims.begin(), reductionDims.end());
99 indices.append(reductionDims.begin(), reductionDims.end());
100 indices.append(parallelDims.begin(), parallelDims.end());
104 Value transposedMask;
105 if (maskableOp.isMasked()) {
106 transposedMask = vector::TransposeOp::create(
107 rewriter, loc, maskableOp.getMaskingOp().getMask(),
indices);
111 auto transposeOp = vector::TransposeOp::create(rewriter, loc, src,
indices);
113 for (
int i = 0; i < reductionSize; ++i) {
114 if (useInnerDimsForReduction)
115 reductionMask[srcRank - i - 1] =
true;
117 reductionMask[i] =
true;
120 Operation *newMultiRedOp = vector::MultiDimReductionOp::create(
121 rewriter, multiReductionOp.getLoc(), transposeOp.getResult(),
122 multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
131 const bool useInnerDimsForReduction;
147class FlattenMultiReduction
152 explicit FlattenMultiReduction(
MLIRContext *context,
153 vector::VectorMultiReductionLowering
options,
156 useInnerDimsForReduction(
157 options == vector::VectorMultiReductionLowering::InnerReduction) {}
159 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
164 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
166 if (maskableOp.isMasked()) {
168 rootOp = maskableOp.getMaskingOp();
170 rootOp = multiReductionOp;
173 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
174 auto srcShape = multiReductionOp.getSourceVectorType().getShape();
175 auto srcScalableDims =
176 multiReductionOp.getSourceVectorType().getScalableDims();
177 auto loc = multiReductionOp.
getLoc();
185 if (llvm::count(srcScalableDims,
true) > 1)
190 if (srcRank == 2 && reductionMask.front() != reductionMask.back())
197 bool isReductionDimScalable =
false;
198 for (
const auto &it : llvm::enumerate(reductionMask)) {
200 bool isReduction = it.value();
202 reductionDims.push_back(i);
203 reductionShapes.push_back(srcShape[i]);
204 isReductionDimScalable |= srcScalableDims[i];
206 parallelDims.push_back(i);
207 parallelShapes.push_back(srcShape[i]);
208 parallelScalableDims.push_back(srcScalableDims[i]);
213 int flattenedParallelDim = 0;
214 int flattenedReductionDim = 0;
215 if (!parallelShapes.empty()) {
216 flattenedParallelDim = 1;
217 for (
auto d : parallelShapes)
218 flattenedParallelDim *= d;
220 if (!reductionShapes.empty()) {
221 flattenedReductionDim = 1;
222 for (
auto d : reductionShapes)
223 flattenedReductionDim *= d;
226 assert((flattenedParallelDim || flattenedReductionDim) &&
227 "expected at least one parallel or reduction dim");
232 if (useInnerDimsForReduction &&
233 llvm::any_of(parallelDims, [&](
int64_t i) {
return i != counter++; }))
236 counter = reductionDims.size();
237 if (!useInnerDimsForReduction &&
238 llvm::any_of(parallelDims, [&](
int64_t i) {
return i != counter++; }))
246 bool isParallelDimScalable = llvm::is_contained(parallelScalableDims,
true);
247 if (flattenedParallelDim) {
248 mask.push_back(
false);
250 scalableDims.push_back(isParallelDimScalable);
252 if (flattenedReductionDim) {
253 mask.push_back(
true);
255 scalableDims.push_back(isReductionDimScalable);
257 if (!useInnerDimsForReduction &&
vectorShape.size() == 2) {
258 std::swap(mask.front(), mask.back());
260 std::swap(scalableDims.front(), scalableDims.back());
264 if (maskableOp.isMasked()) {
265 Value vectorMask = maskableOp.getMaskingOp().getMask();
266 auto maskCastedType = VectorType::get(
268 llvm::cast<VectorType>(vectorMask.
getType()).getElementType());
269 newVectorMask = vector::ShapeCastOp::create(rewriter, loc, maskCastedType,
273 auto castedType = VectorType::get(
274 vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
276 Value cast = vector::ShapeCastOp::create(rewriter, loc, castedType,
277 multiReductionOp.getSource());
279 Value acc = multiReductionOp.getAcc();
280 if (flattenedParallelDim) {
281 auto accType = VectorType::get(
282 {flattenedParallelDim},
284 {isParallelDimScalable});
285 acc = vector::ShapeCastOp::create(rewriter, loc, accType,
acc);
289 Operation *newMultiDimRedOp = vector::MultiDimReductionOp::create(
290 rewriter, loc, cast,
acc, mask, multiReductionOp.getKind());
296 if (parallelShapes.empty()) {
303 VectorType outputCastedType = VectorType::get(
304 parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
305 parallelScalableDims);
307 rootOp, outputCastedType, newMultiDimRedOp->
getResult(0));
312 const bool useInnerDimsForReduction;
341struct TwoDimMultiReductionToElementWise
343 using MaskableOpRewritePattern::MaskableOpRewritePattern;
346 matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
347 vector::MaskingOpInterface maskingOp,
349 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
354 if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
357 Value mask = maskingOp ? maskingOp.getMask() :
Value();
359 auto loc = multiReductionOp.getLoc();
360 Value source = multiReductionOp.getSource();
362 multiReductionOp.getSourceVectorType().getShape();
363 int outerDim = srcShape[0];
366 for (
int64_t i = 0; i < outerDim; i++) {
367 auto v = vector::ExtractOp::create(rewriter, loc, source, i);
368 Value m = mask ?
Value(vector::ExtractOp::create(rewriter, loc, mask, i))
370 result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(), v,
398struct TwoDimMultiReductionToReduction
400 using MaskableOpRewritePattern::MaskableOpRewritePattern;
403 matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
404 vector::MaskingOpInterface maskingOp,
406 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
411 if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
414 Value mask = maskingOp ? maskingOp.getMask() :
nullptr;
416 auto loc = multiReductionOp.
getLoc();
417 Value source = multiReductionOp.getSource();
418 Value acc = multiReductionOp.getAcc();
419 int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
422 rewriter, loc, multiReductionOp.getDestType(),
423 rewriter.
getZeroAttr(multiReductionOp.getDestType()));
426 for (
int64_t i = 0; i < outerDim; ++i) {
427 Value v = vector::ExtractOp::create(rewriter, loc, source, i);
428 Value a = vector::ExtractOp::create(rewriter, loc,
acc, i);
430 Operation *reductionOp = vector::ReductionOp::create(
431 rewriter, loc, multiReductionOp.getKind(), v, a);
434 Value m = vector::ExtractOp::create(rewriter, loc, mask, i);
438 result = vector::InsertOp::create(rewriter, loc,
456struct OneDimMultiReductionToReduction
458 using MaskableOpRewritePattern::MaskableOpRewritePattern;
461 matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
462 vector::MaskingOpInterface maskingOp,
464 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
468 if (!multiReductionOp.isReducedDim(0))
471 auto loc = multiReductionOp.getLoc();
472 Value mask = maskingOp ? maskingOp.getMask() :
Value();
474 Operation *reductionOp = vector::ReductionOp::create(
475 rewriter, loc, multiReductionOp.getKind(), multiReductionOp.getSource(),
476 multiReductionOp.getAcc());
485struct LowerVectorMultiReductionPass
486 :
public vector::impl::LowerVectorMultiReductionBase<
487 LowerVectorMultiReductionPass> {
488 LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
489 this->loweringStrategy = option;
492 void runOnOperation()
override {
498 patterns, this->loweringStrategy);
504 flatteningPatterns, this->loweringStrategy);
510 unrollingPatterns, this->loweringStrategy);
516 registry.
insert<vector::VectorDialect>();
538 patterns.
add<OneDimMultiReductionToReduction>(patterns.
getContext(), benefit);
539 if (
options == VectorMultiReductionLowering ::InnerReduction)
540 patterns.
add<TwoDimMultiReductionToReduction>(patterns.
getContext(),
543 patterns.
add<TwoDimMultiReductionToElementWise>(patterns.
getContext(),
548 vector::VectorMultiReductionLowering option) {
549 return std::make_unique<LowerVectorMultiReductionPass>(option);
static Type getElementType(Type type)
Determine the element type of type.
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< VectorShape > vectorShape(Type type)
TypedAttr getZeroAttr(Type type)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext()
Return the context this operation is associated with.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
std::unique_ptr< Pass > createLowerVectorMultiReductionPass(VectorMultiReductionLowering option=VectorMultiReductionLowering::InnerParallel)
Creates an instance of the vector.multi_reduction lowering pass.
void populateVectorMultiReductionUnrollingPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMultiReductionFlatteningPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMultiReductionReorderPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.