24 #define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
25 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
29 #define DEBUG_TYPE "vector-multi-reduction"
39 class InnerOuterDimReductionConversion
44 explicit InnerOuterDimReductionConversion(
48 useInnerDimsForReduction(
49 options == vector::VectorMultiReductionLowering::InnerReduction) {}
51 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
56 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
58 if (maskableOp.isMasked()) {
60 rootOp = maskableOp.getMaskingOp();
62 rootOp = multiReductionOp;
65 auto src = multiReductionOp.getSource();
66 auto loc = multiReductionOp.
getLoc();
67 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
71 llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
73 int64_t reductionSize = reductionDims.size();
75 for (int64_t i = 0; i < srcRank; ++i)
76 if (!reductionDimsSet.contains(i))
77 parallelDims.push_back(i);
81 if (parallelDims.empty())
83 if (useInnerDimsForReduction &&
85 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
88 if (!useInnerDimsForReduction &&
89 (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
91 parallelDims.size() + reductionDims.size()))))
95 if (useInnerDimsForReduction) {
96 indices.append(parallelDims.begin(), parallelDims.end());
97 indices.append(reductionDims.begin(), reductionDims.end());
99 indices.append(reductionDims.begin(), reductionDims.end());
100 indices.append(parallelDims.begin(), parallelDims.end());
104 Value transposedMask;
105 if (maskableOp.isMasked()) {
106 transposedMask = vector::TransposeOp::create(
107 rewriter, loc, maskableOp.getMaskingOp().getMask(), indices);
111 auto transposeOp = vector::TransposeOp::create(rewriter, loc, src, indices);
113 for (
int i = 0; i < reductionSize; ++i) {
114 if (useInnerDimsForReduction)
115 reductionMask[srcRank - i - 1] =
true;
117 reductionMask[i] =
true;
120 Operation *newMultiRedOp = vector::MultiDimReductionOp::create(
121 rewriter, multiReductionOp.getLoc(), transposeOp.getResult(),
122 multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
131 const bool useInnerDimsForReduction;
136 class ReduceMultiDimReductionRank
141 explicit ReduceMultiDimReductionRank(
145 useInnerDimsForReduction(
146 options == vector::VectorMultiReductionLowering::InnerReduction) {}
148 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
153 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
155 if (maskableOp.isMasked()) {
157 rootOp = maskableOp.getMaskingOp();
159 rootOp = multiReductionOp;
162 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
163 auto srcShape = multiReductionOp.getSourceVectorType().getShape();
164 auto srcScalableDims =
165 multiReductionOp.getSourceVectorType().getScalableDims();
166 auto loc = multiReductionOp.
getLoc();
174 if (llvm::count(srcScalableDims,
true) > 1)
179 if (srcRank == 2 && reductionMask.front() != reductionMask.back())
186 bool isReductionDimScalable =
false;
188 int64_t i = it.index();
189 bool isReduction = it.value();
191 reductionDims.push_back(i);
192 reductionShapes.push_back(srcShape[i]);
193 isReductionDimScalable |= srcScalableDims[i];
195 parallelDims.push_back(i);
196 parallelShapes.push_back(srcShape[i]);
197 parallelScalableDims.push_back(srcScalableDims[i]);
202 int flattenedParallelDim = 0;
203 int flattenedReductionDim = 0;
204 if (!parallelShapes.empty()) {
205 flattenedParallelDim = 1;
206 for (
auto d : parallelShapes)
207 flattenedParallelDim *= d;
209 if (!reductionShapes.empty()) {
210 flattenedReductionDim = 1;
211 for (
auto d : reductionShapes)
212 flattenedReductionDim *= d;
215 assert((flattenedParallelDim || flattenedReductionDim) &&
216 "expected at least one parallel or reduction dim");
221 if (useInnerDimsForReduction &&
222 llvm::any_of(parallelDims, [&](int64_t i) {
return i != counter++; }))
225 counter = reductionDims.size();
226 if (!useInnerDimsForReduction &&
227 llvm::any_of(parallelDims, [&](int64_t i) {
return i != counter++; }))
235 bool isParallelDimScalable = llvm::is_contained(parallelScalableDims,
true);
236 if (flattenedParallelDim) {
237 mask.push_back(
false);
239 scalableDims.push_back(isParallelDimScalable);
241 if (flattenedReductionDim) {
242 mask.push_back(
true);
244 scalableDims.push_back(isReductionDimScalable);
246 if (!useInnerDimsForReduction &&
vectorShape.size() == 2) {
247 std::swap(mask.front(), mask.back());
249 std::swap(scalableDims.front(), scalableDims.back());
253 if (maskableOp.isMasked()) {
254 Value vectorMask = maskableOp.getMaskingOp().getMask();
257 llvm::cast<VectorType>(vectorMask.
getType()).getElementType());
258 newVectorMask = vector::ShapeCastOp::create(rewriter, loc, maskCastedType,
263 vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
265 Value cast = vector::ShapeCastOp::create(rewriter, loc, castedType,
266 multiReductionOp.getSource());
268 Value acc = multiReductionOp.getAcc();
269 if (flattenedParallelDim) {
271 {flattenedParallelDim},
273 {isParallelDimScalable});
274 acc = vector::ShapeCastOp::create(rewriter, loc, accType, acc);
278 Operation *newMultiDimRedOp = vector::MultiDimReductionOp::create(
279 rewriter, loc, cast, acc, mask, multiReductionOp.getKind());
285 if (parallelShapes.empty()) {
292 parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
293 parallelScalableDims);
295 rootOp, outputCastedType, newMultiDimRedOp->
getResult(0));
300 const bool useInnerDimsForReduction;
305 struct TwoDimMultiReductionToElementWise
309 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
311 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
316 if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
319 auto loc = multiReductionOp.getLoc();
321 multiReductionOp.getSourceVectorType().getShape();
329 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
331 Value mask =
nullptr;
332 if (maskableOp.isMasked()) {
334 rootOp = maskableOp.getMaskingOp();
335 mask = maskableOp.getMaskingOp().getMask();
337 rootOp = multiReductionOp;
340 Value result = multiReductionOp.getAcc();
341 for (int64_t i = 0; i < srcShape[0]; i++) {
342 auto operand = vector::ExtractOp::create(rewriter, loc,
343 multiReductionOp.getSource(), i);
344 Value extractMask =
nullptr;
346 extractMask = vector::ExtractOp::create(rewriter, loc, mask, i);
350 result,
nullptr, extractMask);
360 struct TwoDimMultiReductionToReduction
364 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
366 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
370 if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
376 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
378 if (maskableOp.isMasked()) {
380 rootOp = maskableOp.getMaskingOp();
382 rootOp = multiReductionOp;
385 auto loc = multiReductionOp.
getLoc();
386 Value result = arith::ConstantOp::create(
387 rewriter, loc, multiReductionOp.getDestType(),
388 rewriter.
getZeroAttr(multiReductionOp.getDestType()));
389 int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
391 for (
int i = 0; i < outerDim; ++i) {
392 auto v = vector::ExtractOp::create(
394 auto acc = vector::ExtractOp::create(
396 Operation *reductionOp = vector::ReductionOp::create(
397 rewriter, loc, multiReductionOp.getKind(), v, acc);
400 if (maskableOp.isMasked()) {
401 Value mask = vector::ExtractOp::create(
402 rewriter, loc, maskableOp.getMaskingOp().getMask(),
407 result = vector::InsertOp::create(rewriter, loc,
421 struct OneDimMultiReductionToTwoDim
425 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
427 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
435 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
438 if (maskableOp.isMasked()) {
440 rootOp = maskableOp.getMaskingOp();
441 mask = maskableOp.getMaskingOp().getMask();
443 rootOp = multiReductionOp;
446 auto loc = multiReductionOp.
getLoc();
447 auto srcVectorType = multiReductionOp.getSourceVectorType();
448 auto srcShape = srcVectorType.getShape();
455 assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
456 "multi_reduction with a single dimension expects a scalar result");
463 Value cast = vector::ShapeCastOp::create(rewriter, loc, castedType,
464 multiReductionOp.getSource());
465 Value castAcc = vector::BroadcastOp::create(rewriter, loc, accType,
466 multiReductionOp.getAcc());
468 if (maskableOp.isMasked()) {
469 auto maskType = llvm::cast<VectorType>(mask.
getType());
472 maskType.getElementType(),
474 castMask = vector::BroadcastOp::create(rewriter, loc, castMaskType, mask);
477 Operation *newOp = vector::MultiDimReductionOp::create(
478 rewriter, loc, cast, castAcc, reductionMask,
479 multiReductionOp.getKind());
488 struct LowerVectorMultiReductionPass
489 :
public vector::impl::LowerVectorMultiReductionBase<
490 LowerVectorMultiReductionPass> {
491 LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
492 this->loweringStrategy = option;
495 void runOnOperation()
override {
501 this->loweringStrategy);
508 registry.
insert<vector::VectorDialect>();
517 patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
520 if (
options == VectorMultiReductionLowering ::InnerReduction)
529 vector::VectorMultiReductionLowering option) {
530 return std::make_unique<LowerVectorMultiReductionPass>(option);
static Type getElementType(Type type)
Determine the element type of type.
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< VectorShape > vectorShape(Type type)
TypedAttr getZeroAttr(Type type)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
std::unique_ptr< Pass > createLowerVectorMultiReductionPass(VectorMultiReductionLowering option=VectorMultiReductionLowering::InnerParallel)
Creates an instance of the vector.multi_reduction lowering pass.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector....
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...