#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace vector {
#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
} // namespace vector
} // namespace mlir

#define DEBUG_TYPE "vector-multi-reduction"

using namespace mlir;

namespace {
/// Converts vector.multi_reduction into inner-most/outer-most reduction form
/// by using vector.transpose.
class InnerOuterDimReductionConversion
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  explicit InnerOuterDimReductionConversion(
      MLIRContext *context, vector::VectorMultiReductionLowering options,
      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    // If the reduction is wrapped in a vector.mask, that masking op is the
    // root op to replace.
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }
    auto src = multiReductionOp.getSource();
    auto loc = multiReductionOp.getLoc();
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();

    // Separate reduction and parallel dims.
    auto reductionDims = multiReductionOp.getReductionDims();
    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
                                                  reductionDims.end());
    int64_t reductionSize = reductionDims.size();
    SmallVector<int64_t, 4> parallelDims;
    for (int64_t i = 0; i < srcRank; ++i)
      if (!reductionDimsSet.contains(i))
        parallelDims.push_back(i);
    // Add a transpose only if the parallel dims are not already in the
    // desired (leading or trailing) position.
    if (parallelDims.empty())
      return failure();
    if (useInnerDimsForReduction &&
        (parallelDims ==
         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
      return failure();
    if (!useInnerDimsForReduction &&
        (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
                             reductionDims.size(),
                             parallelDims.size() + reductionDims.size()))))
      return failure();
    // Build the transpose permutation: parallel dims first for InnerReduction,
    // reduction dims first otherwise.
    SmallVector<int64_t, 4> indices;
    if (useInnerDimsForReduction) {
      indices.append(parallelDims.begin(), parallelDims.end());
      indices.append(reductionDims.begin(), reductionDims.end());
    } else {
      indices.append(reductionDims.begin(), reductionDims.end());
      indices.append(parallelDims.begin(), parallelDims.end());
    }
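    // Worked example (illustrative, not from the original source): for a
    // rank-4 source with reduction dims [0, 2] and parallel dims [1, 3], the
    // InnerReduction strategy yields indices = [1, 3, 0, 2], i.e. all parallel
    // dims first, all reduction dims last.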
    // If masked, transpose the original mask with the same permutation.
    Value transposedMask;
    if (maskableOp.isMasked()) {
      transposedMask = rewriter.create<vector::TransposeOp>(
          loc, maskableOp.getMaskingOp().getMask(), indices);
    }

    // Transpose the reduction source.
    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
    // After the transpose, the reduced dims are the inner-most (resp.
    // outer-most) dims of the source.
    SmallVector<bool> reductionMask(srcRank, false);
    for (int i = 0; i < reductionSize; ++i) {
      if (useInnerDimsForReduction)
        reductionMask[srcRank - i - 1] = true;
      else
        reductionMask[i] = true;
    }

    Operation *newMultiRedOp = rewriter.create<vector::MultiDimReductionOp>(
        multiReductionOp.getLoc(), transposeOp.getResult(),
        multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());

    // Re-apply masking, if any, and replace the root op with the new result.
    if (maskableOp.isMasked())
      newMultiRedOp =
          vector::maskOperation(rewriter, newMultiRedOp, transposedMask);
    rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0));
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};
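// A hedged IR-level sketch of the rewrite above (illustrative assembly only;
// exact syntax may vary across MLIR versions). Continuing the rank-4 example,
//
//   %0 = vector.multi_reduction <add>, %src, %acc [0, 2]
//          : vector<4x8x16x32xf32> to vector<8x32xf32>
//
// is rewritten under InnerReduction into a transpose that moves the parallel
// dims to the front, followed by an inner-dim reduction:
//
//   %t = vector.transpose %src, [1, 3, 0, 2]
//          : vector<4x8x16x32xf32> to vector<8x32x4x16xf32>
//   %0 = vector.multi_reduction <add>, %t, %acc [2, 3]
//          : vector<8x32x4x16xf32> to vector<8x32xf32>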
/// Reduces the rank of vector.multi_reduction to 2-D, provided all reduction
/// dims are either inner-most or outer-most.
class ReduceMultiDimReductionRank
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  explicit ReduceMultiDimReductionRank(
      MLIRContext *context, vector::VectorMultiReductionLowering options,
      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup, as in InnerOuterDimReductionConversion.
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    auto srcShape = multiReductionOp.getSourceVectorType().getShape();
    auto srcScalableDims =
        multiReductionOp.getSourceVectorType().getScalableDims();
    auto loc = multiReductionOp.getLoc();

    // Nothing to do for rank < 2; bail out on more than one scalable dim.
    if (srcRank < 2)
      return failure();
    if (llvm::count(srcScalableDims, true) > 1)
      return failure();

    // If the op is already in 2-D ["parallel", "reduce"] or
    // ["reduce", "parallel"] form, there is nothing to flatten.
    auto reductionMask = multiReductionOp.getReductionMask();
    if (srcRank == 2 && reductionMask.front() != reductionMask.back())
      return failure();
    // 1. Separate reduction and parallel dims.
    SmallVector<int64_t, 4> parallelDims, parallelShapes;
    SmallVector<bool, 4> parallelScalableDims;
    SmallVector<int64_t, 4> reductionDims, reductionShapes;
    bool isReductionDimScalable = false;
    for (const auto &it : llvm::enumerate(reductionMask)) {
      int64_t i = it.index();
      bool isReduction = it.value();
      if (isReduction) {
        reductionDims.push_back(i);
        reductionShapes.push_back(srcShape[i]);
        isReductionDimScalable |= srcScalableDims[i];
      } else {
        parallelDims.push_back(i);
        parallelShapes.push_back(srcShape[i]);
        parallelScalableDims.push_back(srcScalableDims[i]);
      }
    }
    // 2. Compute the flattened parallel and reduction sizes.
    int flattenedParallelDim = 0;
    int flattenedReductionDim = 0;
    if (!parallelShapes.empty()) {
      flattenedParallelDim = 1;
      for (auto d : parallelShapes)
        flattenedParallelDim *= d;
    }
    if (!reductionShapes.empty()) {
      flattenedReductionDim = 1;
      for (auto d : reductionShapes)
        flattenedReductionDim *= d;
    }
    assert((flattenedParallelDim || flattenedReductionDim) &&
           "expected at least one parallel or reduction dim");
    // 3. Fail if the reduction dims are not all inner-most (InnerReduction)
    // or all outer-most (InnerParallel), i.e. the parallel dims do not form a
    // contiguous leading/trailing range.
    int64_t counter = 0;
    if (useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();
    counter = reductionDims.size();
    if (!useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();
    // 4. Build the flattened shape, reduction mask and scalable dims, with the
    // reduction dim last for InnerReduction and first otherwise.
    SmallVector<bool, 2> mask;
    SmallVector<bool, 2> scalableDims;
    SmallVector<int64_t, 2> vectorShape;
    bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
    if (flattenedParallelDim) {
      mask.push_back(false);
      vectorShape.push_back(flattenedParallelDim);
      scalableDims.push_back(isParallelDimScalable);
    }
    if (flattenedReductionDim) {
      mask.push_back(true);
      vectorShape.push_back(flattenedReductionDim);
      scalableDims.push_back(isReductionDimScalable);
    }
    if (!useInnerDimsForReduction && vectorShape.size() == 2) {
      std::swap(mask.front(), mask.back());
      std::swap(vectorShape.front(), vectorShape.back());
      std::swap(scalableDims.front(), scalableDims.back());
    }
    // 5. If masked, shape-cast the mask to the flattened shape as well.
    Value newVectorMask;
    if (maskableOp.isMasked()) {
      Value vectorMask = maskableOp.getMaskingOp().getMask();
      auto maskCastedType = VectorType::get(
          vectorShape,
          llvm::cast<VectorType>(vectorMask.getType()).getElementType());
      newVectorMask =
          rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
    }

    // Shape-cast the source to the flattened shape.
    auto castedType = VectorType::get(
        vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
        scalableDims);
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());
    // Shape-cast the accumulator to the flattened parallel shape.
    Value acc = multiReductionOp.getAcc();
    if (flattenedParallelDim) {
      auto accType = VectorType::get(
          {flattenedParallelDim},
          multiReductionOp.getSourceVectorType().getElementType(),
          {isParallelDimScalable});
      acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
    }

    // 6. Create the flattened multi_reduction and re-apply masking, if any
    // (maskOperation is a no-op for a null mask).
    Operation *newMultiDimRedOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, acc, mask, multiReductionOp.getKind());
    newMultiDimRedOp =
        vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask);
    // 7. If there are no parallel dims, the flattened reduction produces a
    // scalar and can replace the root op directly.
    if (parallelShapes.empty()) {
      rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0));
      return success();
    }

    // 8. Otherwise shape-cast the 1-D result back to the original parallel
    // shape.
    VectorType outputCastedType = VectorType::get(
        parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
        parallelScalableDims);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};
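// A hedged IR-level sketch of ReduceMultiDimReductionRank under InnerReduction
// (illustrative assembly only), continuing the 2x5x3x4 example above:
//
//   %0 = vector.multi_reduction <add>, %src, %acc [2, 3]
//          : vector<2x5x3x4xf32> to vector<2x5xf32>
//
// is flattened to a 2-D reduction wrapped in shape casts:
//
//   %s = vector.shape_cast %src : vector<2x5x3x4xf32> to vector<10x12xf32>
//   %a = vector.shape_cast %acc : vector<2x5xf32> to vector<10xf32>
//   %r = vector.multi_reduction <add>, %s, %a [1]
//          : vector<10x12xf32> to vector<10xf32>
//   %0 = vector.shape_cast %r : vector<10xf32> to vector<2x5xf32>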
/// Unrolls a 2-D vector.multi_reduction over the outer-most (reduction) dim
/// and combines the rows elementwise.
struct TwoDimMultiReductionToElementWise
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    if (srcRank != 2)
      return failure();

    // Only the ["reduce", "parallel"] form is handled here.
    if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
      return failure();

    auto loc = multiReductionOp.getLoc();
    ArrayRef<int64_t> srcShape =
        multiReductionOp.getSourceVectorType().getShape();

    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
    if (!elementType.isIntOrIndexOrFloat())
      return failure();

    // Vector mask setup.
    Operation *rootOp;
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Value mask = nullptr;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = multiReductionOp;
    }

    // Accumulate row by row with an elementwise arith reduction.
    Value result = multiReductionOp.getAcc();
    for (int64_t i = 0; i < srcShape[0]; i++) {
      auto operand = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), i);
      Value extractMask = nullptr;
      if (mask)
        extractMask = rewriter.create<vector::ExtractOp>(loc, mask, i);
      result = vector::makeArithReduction(rewriter, loc,
                                          multiReductionOp.getKind(), operand,
                                          result, /*fastmath=*/nullptr,
                                          extractMask);
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
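// A hedged IR-level sketch of TwoDimMultiReductionToElementWise (illustrative
// assembly only):
//
//   %0 = vector.multi_reduction <add>, %src, %acc [0]
//          : vector<4x8xf32> to vector<8xf32>
//
// becomes a chain of elementwise combines, one per outer row of %src:
//
//   %e0 = vector.extract %src[0] ...
//   %r0 = arith.addf %e0, %acc : vector<8xf32>
//   %e1 = vector.extract %src[1] ...
//   %r1 = arith.addf %e1, %r0 : vector<8xf32>
//   ... and so on, with the last combine replacing %0.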
/// Converts a 2-D vector.multi_reduction with an inner-most reduction dim into
/// a sequence of vector.reduction ops, one per outer row.
struct TwoDimMultiReductionToReduction
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    if (srcRank != 2)
      return failure();

    // Only the ["parallel", "reduce"] form is handled here.
    if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
      return failure();

    // Vector mask setup.
    Operation *rootOp;
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto loc = multiReductionOp.getLoc();
    // Start from a zero vector and insert one reduced scalar per row.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, multiReductionOp.getDestType(),
        rewriter.getZeroAttr(multiReductionOp.getDestType()));
    int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];

    for (int i = 0; i < outerDim; ++i) {
      auto v = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), i);
      auto acc = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getAcc(), i);
      Operation *reductionOp = rewriter.create<vector::ReductionOp>(
          loc, multiReductionOp.getKind(), v, acc);

      // If masked, slice the mask for this row and mask the new reduction.
      if (maskableOp.isMasked()) {
        Value mask = rewriter.create<vector::ExtractOp>(
            loc, maskableOp.getMaskingOp().getMask(), i);
        reductionOp = vector::maskOperation(rewriter, reductionOp, mask);
      }

      result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
                                                 result, i);
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
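// A hedged IR-level sketch of TwoDimMultiReductionToReduction (illustrative
// assembly only):
//
//   %0 = vector.multi_reduction <add>, %src, %acc [1]
//          : vector<2x8xf32> to vector<2xf32>
//
// becomes one vector.reduction per outer row, inserted into a zero-initialized
// result vector:
//
//   %v0 = vector.extract %src[0] ...
//   %a0 = vector.extract %acc[0] ...
//   %r0 = vector.reduction <add>, %v0, %a0 : vector<8xf32> into f32
//   %t0 = vector.insert %r0, %cst [0] : f32 into vector<2xf32>
//   ... likewise for row 1, producing %0.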
/// Converts a 1-D vector.multi_reduction (a single reduced dim and a scalar
/// result) into a 2-D form with a leading unit parallel dim, so the 2-D
/// patterns above can finish the lowering.
struct OneDimMultiReductionToTwoDim
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    if (srcRank != 1)
      return failure();

    // Vector mask setup.
    Operation *rootOp;
    Value mask;
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = multiReductionOp;
    }

    auto loc = multiReductionOp.getLoc();
    auto srcVectorType = multiReductionOp.getSourceVectorType();
    auto srcShape = srcVectorType.getShape();
    auto castedType = VectorType::get(
        ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),
        ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});
    auto accType =
        VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
    assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
           "multi_reduction with a single dimension expects a scalar result");

    // The unique source dim is reduced and a unit parallel dim is prepended.
    SmallVector<bool, 2> reductionMask{false, true};

    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());
    Value castAcc = rewriter.create<vector::BroadcastOp>(
        loc, accType, multiReductionOp.getAcc());
    Value castMask;
    if (maskableOp.isMasked()) {
      auto maskType = llvm::cast<VectorType>(mask.getType());
      auto castMaskType = VectorType::get(
          ArrayRef<int64_t>{1, maskType.getShape().back()},
          maskType.getElementType(),
          ArrayRef<bool>{false, maskType.getScalableDims().back()});
      castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
    }

    Operation *newOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
    newOp = vector::maskOperation(rewriter, newOp, castMask);

    rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
                                                   0);
    return success();
  }
};
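// A hedged IR-level sketch of OneDimMultiReductionToTwoDim (illustrative
// assembly only):
//
//   %0 = vector.multi_reduction <add>, %src, %acc [0] : vector<8xf32> to f32
//
// becomes a 2-D reduction over a leading unit dim, followed by an extract of
// the single resulting lane:
//
//   %s = vector.shape_cast %src : vector<8xf32> to vector<1x8xf32>
//   %a = vector.broadcast %acc : f32 to vector<1xf32>
//   %r = vector.multi_reduction <add>, %s, %a [1]
//          : vector<1x8xf32> to vector<1xf32>
//   %0 = vector.extract %r[0] ...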
struct LowerVectorMultiReductionPass
    : public vector::impl::LowerVectorMultiReductionBase<
          LowerVectorMultiReductionPass> {
  LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
    this->loweringStrategy = option;
  }

  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();

    RewritePatternSet loweringPatterns(context);
    vector::populateVectorMultiReductionLoweringPatterns(
        loweringPatterns, this->loweringStrategy);

    if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
      signalPassFailure();
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }
};
} // namespace

void mlir::vector::populateVectorMultiReductionLoweringPatterns(
    RewritePatternSet &patterns, VectorMultiReductionLowering options,
    PatternBenefit benefit) {
  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
      patterns.getContext(), options, benefit);
  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
  if (options == VectorMultiReductionLowering::InnerReduction)
    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
                                                  benefit);
  else
    patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
                                                    benefit);
}

std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
    vector::VectorMultiReductionLowering option) {
  return std::make_unique<LowerVectorMultiReductionPass>(option);
}
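// A minimal usage sketch (illustrative only, not part of this file). The pass
// and pattern entry points above can be used either from a pass pipeline or
// from another transform; `pm`, `ctx`, and `op` are assumed to exist in the
// caller:
//
//   pm.addPass(vector::createLowerVectorMultiReductionPass(
//       vector::VectorMultiReductionLowering::InnerReduction));
//
// or
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorMultiReductionLoweringPatterns(
//       patterns, vector::VectorMultiReductionLowering::InnerParallel);
//   (void)applyPatternsGreedily(op, std::move(patterns));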