24 #define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
25 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
29 #define DEBUG_TYPE "vector-multi-reduction"
// InnerOuterDimReductionConversion: normalizes a vector.multi_reduction by
// transposing its source (and its mask, when the op is masked) so that the
// reduction dims end up innermost (InnerReduction strategy) or outermost
// (any other strategy), then re-emits the multi_reduction on the transposed
// value.
// NOTE(review): this is a fragmentary extraction; many original lines of this
// class (braces, declarations, return statements) are missing, so only the
// visible statements are documented here.
39 class InnerOuterDimReductionConversion
44 explicit InnerOuterDimReductionConversion(
// The strategy flag is derived once from the lowering option at construction.
48 useInnerDimsForReduction(
49 options == vector::VectorMultiReductionLowering::InnerReduction) {}
51 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
56 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
// rootOp is the masking op when the reduction is masked, otherwise the
// reduction itself (its later use is on lines not visible here).
58 if (maskableOp.isMasked()) {
60 rootOp = maskableOp.getMaskingOp();
62 rootOp = multiReductionOp;
65 auto src = multiReductionOp.getSource();
66 auto loc = multiReductionOp.
getLoc();
67 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
// Every source dim not in the reduction set is a parallel dim, collected in
// ascending order.
71 llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
73 int64_t reductionSize = reductionDims.size();
75 for (int64_t i = 0; i < srcRank; ++i)
76 if (!reductionDimsSet.contains(i))
77 parallelDims.push_back(i);
// Bail-out checks (their return statements are not visible in this excerpt).
// They trigger when there are no parallel dims, or when parallelDims appears
// to already equal the sequence of leading positions (InnerReduction) or of
// trailing positions (other strategies). Part of each comparison is on
// missing lines.
81 if (parallelDims.empty())
83 if (useInnerDimsForReduction &&
85 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
88 if (!useInnerDimsForReduction &&
89 (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
91 parallelDims.size() + reductionDims.size()))))
// Transpose permutation: parallel dims then reduction dims for
// InnerReduction; reduction dims then parallel dims otherwise.
95 if (useInnerDimsForReduction) {
96 indices.append(parallelDims.begin(), parallelDims.end());
97 indices.append(reductionDims.begin(), reductionDims.end());
99 indices.append(reductionDims.begin(), reductionDims.end());
100 indices.append(parallelDims.begin(), parallelDims.end());
// The mask is transposed with the same `indices` as the source, so mask
// bits stay aligned with their elements.
104 Value transposedMask;
105 if (maskableOp.isMasked()) {
106 transposedMask = rewriter.
create<vector::TransposeOp>(
107 loc, maskableOp.getMaskingOp().getMask(), indices);
111 auto transposeOp = rewriter.
create<vector::TransposeOp>(loc, src, indices);
// After the transpose, the reduction dims are the last `reductionSize`
// positions (InnerReduction) or the first `reductionSize` positions.
// NOTE(review): `int i` is compared against the int64_t `reductionSize`;
// consider int64_t for the loop index.
113 for (
int i = 0; i < reductionSize; ++i) {
114 if (useInnerDimsForReduction)
115 reductionMask[srcRank - i - 1] =
true;
117 reductionMask[i] =
true;
// Re-emit the reduction on the transposed source, keeping the original
// accumulator and combining kind. The way `transposedMask` is applied to
// this op is on lines not visible here.
120 Operation *newMultiRedOp = rewriter.
create<vector::MultiDimReductionOp>(
121 multiReductionOp.getLoc(), transposeOp.getResult(),
122 multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
131 const bool useInnerDimsForReduction;
// ReduceMultiDimReductionRank: collapses a vector.multi_reduction to rank <= 2.
// All parallel dims are folded into one dim and all reduction dims into one
// dim. The flattened value is reduced, and the result is cast back to the
// original parallel shape.
// NOTE(review): fragmentary extraction; many original lines (braces,
// declarations, return statements) are missing, so only the visible
// statements are documented here.
136 class ReduceMultiDimReductionRank
141 explicit ReduceMultiDimReductionRank(
145 useInnerDimsForReduction(
146 options == vector::VectorMultiReductionLowering::InnerReduction) {}
148 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
153 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
// rootOp is the masking op when masked, otherwise the reduction itself.
155 if (maskableOp.isMasked()) {
157 rootOp = maskableOp.getMaskingOp();
159 rootOp = multiReductionOp;
162 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
163 auto srcShape = multiReductionOp.getSourceVectorType().getShape();
164 auto srcScalableDims =
165 multiReductionOp.getSourceVectorType().getScalableDims();
166 auto loc = multiReductionOp.
getLoc();
// Bail-out conditions (return statements not visible here):
//  - the source has more than one scalable dim;
//  - the source is rank-2 and its first and last dims have different
//    reduction flags.
174 if (llvm::count(srcScalableDims,
true) > 1)
179 if (srcRank == 2 && reductionMask.front() != reductionMask.back())
// Partition the dims by their reduction flag. Record each group's sizes; a
// single scalable flag is kept for the reduction group and per-dim scalable
// flags for the parallel group.
186 bool isReductionDimScalable =
false;
188 int64_t i = it.index();
189 bool isReduction = it.value();
191 reductionDims.push_back(i);
192 reductionShapes.push_back(srcShape[i]);
193 isReductionDimScalable |= srcScalableDims[i];
195 parallelDims.push_back(i);
196 parallelShapes.push_back(srcShape[i]);
197 parallelScalableDims.push_back(srcScalableDims[i]);
// Product of each group's sizes; it stays 0 when the group is empty.
// NOTE(review): the products are accumulated in `int` from shape entries
// that are presumably int64_t. Very large vectors could overflow; consider
// int64_t.
202 int flattenedParallelDim = 0;
203 int flattenedReductionDim = 0;
204 if (!parallelShapes.empty()) {
205 flattenedParallelDim = 1;
206 for (
auto d : parallelShapes)
207 flattenedParallelDim *= d;
209 if (!reductionShapes.empty()) {
210 flattenedReductionDim = 1;
211 for (
auto d : reductionShapes)
212 flattenedReductionDim *= d;
215 assert((flattenedParallelDim || flattenedReductionDim) &&
216 "expected at least one parallel or reduction dim");
// Require the parallel dims to be contiguous: presumably 0, 1, ... for
// InnerReduction (counter's initial value is not visible), or starting at
// reductionDims.size() for other strategies. The early returns taken when
// they are not contiguous are not visible here.
221 if (useInnerDimsForReduction &&
222 llvm::any_of(parallelDims, [&](int64_t i) {
return i != counter++; }))
225 counter = reductionDims.size();
226 if (!useInnerDimsForReduction &&
227 llvm::any_of(parallelDims, [&](int64_t i) {
return i != counter++; }))
// Build the reduction flags and scalable flags for the flattened type, one
// entry per non-empty group with the parallel group first. The matching
// vectorShape pushes are on lines not visible here. For non-inner
// strategies with two entries, the order is swapped so the reduction dim
// comes first.
235 bool isParallelDimScalable = llvm::is_contained(parallelScalableDims,
true);
236 if (flattenedParallelDim) {
237 mask.push_back(
false);
239 scalableDims.push_back(isParallelDimScalable);
241 if (flattenedReductionDim) {
242 mask.push_back(
true);
244 scalableDims.push_back(isReductionDimScalable);
246 if (!useInnerDimsForReduction &&
vectorShape.size() == 2) {
247 std::swap(mask.front(), mask.back());
249 std::swap(scalableDims.front(), scalableDims.back());
// If the op is masked, shape_cast the vector mask to the flattened shape,
// keeping its element type.
253 if (maskableOp.isMasked()) {
254 Value vectorMask = maskableOp.getMaskingOp().getMask();
257 llvm::cast<VectorType>(vectorMask.
getType()).getElementType());
259 rewriter.
create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
// Cast the source to the flattened type. The creating call is not visible;
// presumably it is also a vector.shape_cast.
263 vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
266 loc, castedType, multiReductionOp.getSource());
// When there are parallel dims, the accumulator is shape_cast to the 1-D
// flattened parallel shape.
268 Value acc = multiReductionOp.getAcc();
269 if (flattenedParallelDim) {
271 {flattenedParallelDim},
273 {isParallelDimScalable});
274 acc = rewriter.
create<vector::ShapeCastOp>(loc, accType, acc);
// The rank-reduced reduction. How the cast vector mask is applied to it is
// on lines not visible here.
278 Operation *newMultiDimRedOp = rewriter.
create<vector::MultiDimReductionOp>(
279 loc, cast, acc, mask, multiReductionOp.getKind());
// With no parallel dims, the branch body is not visible here. Otherwise,
// rootOp is replaced by a value of the original parallel shape (and its
// scalable flags) built from newMultiDimRedOp's result.
285 if (parallelShapes.empty()) {
292 parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
293 parallelScalableDims);
295 rootOp, outputCastedType, newMultiDimRedOp->
getResult(0));
300 const bool useInnerDimsForReduction;
// TwoDimMultiReductionToElementWise: lowers a multi_reduction that reduces
// dim 0 and keeps dim 1 (the rank check itself is not visible here). Each
// row of the source is extracted and combined into the accumulator.
// NOTE(review): fragmentary extraction; original lines are missing.
305 struct TwoDimMultiReductionToElementWise
309 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
312 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
// The consequence of a masked op is not visible (presumably a bail-out).
313 if (maskableOp.isMasked())
317 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
// Only the [reduced dim 0, parallel dim 1] form passes this check.
322 if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
325 auto loc = multiReductionOp.getLoc();
327 multiReductionOp.getSourceVectorType().getShape();
// Start from the accumulator and extract one row per iteration. The call
// that combines the row into `result` is on lines not visible here.
333 Value result = multiReductionOp.getAcc();
334 for (int64_t i = 0; i < srcShape[0]; i++) {
335 auto operand = rewriter.
create<vector::ExtractOp>(
336 loc, multiReductionOp.getSource(), i);
341 rewriter.
replaceOp(multiReductionOp, result);
// TwoDimMultiReductionToReduction: lowers a multi_reduction that keeps dim 0
// and reduces dim 1 into one 1-D reduction per row. Each scalar result is
// inserted into a zero-initialized destination vector.
// NOTE(review): fragmentary extraction; original lines are missing, and the
// struct continues past the last line shown here.
348 struct TwoDimMultiReductionToReduction
352 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
354 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
// Only the [parallel dim 0, reduced dim 1] form passes this check.
358 if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
364 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
// rootOp is the masking op when masked, otherwise the reduction itself.
366 if (maskableOp.isMasked()) {
368 rootOp = maskableOp.getMaskingOp();
370 rootOp = multiReductionOp;
373 auto loc = multiReductionOp.
getLoc();
// The result appears to start as a zero constant of the destination type.
375 loc, multiReductionOp.getDestType(),
376 rewriter.
getZeroAttr(multiReductionOp.getDestType()));
377 int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
// For each row:
//  - extract the row and the matching accumulator element;
//  - reduce them with the op's combining kind (the op creation is only
//    partly visible);
//  - handle the mask when masked;
//  - insert the scalar result into `result`.
379 for (
int i = 0; i < outerDim; ++i) {
380 auto v = rewriter.
create<vector::ExtractOp>(
382 auto acc = rewriter.
create<vector::ExtractOp>(
385 loc, multiReductionOp.getKind(), v, acc);
388 if (maskableOp.isMasked()) {
394 result = rewriter.
create<vector::InsertOp>(loc, reductionOp->getResult(0),
// OneDimMultiReductionToTwoDim: rewrites a single-dimension multi_reduction,
// which yields a scalar, as a reduction over a re-shaped source.
// - The accumulator is broadcast to `accType`.
// - The mask, if any, is broadcast to `castMaskType`.
// The target types are built on lines not visible here; by the pattern's
// name they are presumably 2-D.
// NOTE(review): fragmentary extraction; original lines are missing.
408 struct OneDimMultiReductionToTwoDim
412 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
414 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
422 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
// When masked, both the masking op (as root) and its mask value are captured.
425 if (maskableOp.isMasked()) {
427 rootOp = maskableOp.getMaskingOp();
428 mask = maskableOp.getMaskingOp().getMask();
430 rootOp = multiReductionOp;
433 auto loc = multiReductionOp.
getLoc();
434 auto srcVectorType = multiReductionOp.getSourceVectorType();
435 auto srcShape = srcVectorType.getShape();
442 assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
443 "multi_reduction with a single dimension expects a scalar result");
451 loc, castedType, multiReductionOp.getSource());
452 Value castAcc = rewriter.
create<vector::BroadcastOp>(
453 loc, accType, multiReductionOp.getAcc());
// The mask is broadcast to a type that keeps the original mask element type.
455 if (maskableOp.isMasked()) {
456 auto maskType = llvm::cast<VectorType>(mask.
getType());
459 maskType.getElementType(),
461 castMask = rewriter.
create<vector::BroadcastOp>(loc, castMaskType, mask);
465 loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
// LowerVectorMultiReductionPass: pass wrapper that stores the requested
// lowering strategy in `loweringStrategy` and uses it in runOnOperation.
// NOTE(review): fragmentary extraction; original lines are missing.
474 struct LowerVectorMultiReductionPass
475 :
public vector::impl::LowerVectorMultiReductionBase<
476 LowerVectorMultiReductionPass> {
// NOTE(review): this single-argument constructor is not `explicit`, so it
// permits implicit conversion from the lowering enum.
477 LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
478 this->loweringStrategy = option;
// The stored strategy is forwarded to a call whose start is not visible
// here (presumably the pattern-population function).
481 void runOnOperation()
override {
487 this->loweringStrategy);
// Registers the vector dialect in a DialectRegistry (the enclosing method
// is not visible here).
494 registry.
insert<vector::VectorDialect>();
// Adds the two normalization patterns (no condition is visible on this
// line). The patterns added when `options` is InnerReduction, and the other
// branch, are on lines not visible here.
503 patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
506 if (
options == VectorMultiReductionLowering ::InnerReduction)
// Factory: returns a heap-allocated LowerVectorMultiReductionPass configured
// with the given lowering strategy.
515 vector::VectorMultiReductionLowering option) {
516 return std::make_unique<LowerVectorMultiReductionPass>(option);
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< VectorShape > vectorShape(Type type)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
TypedAttr getZeroAttr(Type type)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
std::unique_ptr< Pass > createLowerVectorMultiReductionPass(VectorMultiReductionLowering option=VectorMultiReductionLowering::InnerParallel)
Creates an instance of the vector.multi_reduction lowering pass.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector....
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...