24 #define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
25 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
29 #define DEBUG_TYPE "vector-multi-reduction"
39 class InnerOuterDimReductionConversion
44 explicit InnerOuterDimReductionConversion(
48 useInnerDimsForReduction(
49 options == vector::VectorMultiReductionLowering::InnerReduction) {}
51 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
56 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
58 if (maskableOp.isMasked()) {
60 rootOp = maskableOp.getMaskingOp();
62 rootOp = multiReductionOp;
65 auto src = multiReductionOp.getSource();
66 auto loc = multiReductionOp.
getLoc();
67 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
70 auto reductionDimsRange =
71 multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
72 auto reductionDims = llvm::to_vector<4>(llvm::map_range(
73 reductionDimsRange, [](
const APInt &a) {
return a.getZExtValue(); }));
74 llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
76 int64_t reductionSize = reductionDims.size();
78 for (int64_t i = 0; i < srcRank; ++i)
79 if (!reductionDimsSet.contains(i))
80 parallelDims.push_back(i);
84 if (parallelDims.empty())
86 if (useInnerDimsForReduction &&
88 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
91 if (!useInnerDimsForReduction &&
92 (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
94 parallelDims.size() + reductionDims.size()))))
98 if (useInnerDimsForReduction) {
99 indices.append(parallelDims.begin(), parallelDims.end());
100 indices.append(reductionDims.begin(), reductionDims.end());
102 indices.append(reductionDims.begin(), reductionDims.end());
103 indices.append(parallelDims.begin(), parallelDims.end());
107 Value transposedMask;
108 if (maskableOp.isMasked()) {
109 transposedMask = rewriter.
create<vector::TransposeOp>(
110 loc, maskableOp.getMaskingOp().getMask(), indices);
114 auto transposeOp = rewriter.
create<vector::TransposeOp>(loc, src, indices);
116 for (
int i = 0; i < reductionSize; ++i) {
117 if (useInnerDimsForReduction)
118 reductionMask[srcRank - i - 1] =
true;
120 reductionMask[i] =
true;
123 Operation *newMultiRedOp = rewriter.
create<vector::MultiDimReductionOp>(
124 multiReductionOp.getLoc(), transposeOp.getResult(),
125 multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
134 const bool useInnerDimsForReduction;
139 class ReduceMultiDimReductionRank
144 explicit ReduceMultiDimReductionRank(
148 useInnerDimsForReduction(
149 options == vector::VectorMultiReductionLowering::InnerReduction) {}
151 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
156 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
158 if (maskableOp.isMasked()) {
160 rootOp = maskableOp.getMaskingOp();
162 rootOp = multiReductionOp;
165 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
166 auto srcShape = multiReductionOp.getSourceVectorType().getShape();
167 auto srcScalableDims =
168 multiReductionOp.getSourceVectorType().getScalableDims();
169 auto loc = multiReductionOp.
getLoc();
177 if (llvm::count(srcScalableDims,
true) > 1)
182 if (srcRank == 2 && reductionMask.front() != reductionMask.back())
189 bool isReductionDimScalable =
false;
191 int64_t i = it.index();
192 bool isReduction = it.value();
194 reductionDims.push_back(i);
195 reductionShapes.push_back(srcShape[i]);
196 isReductionDimScalable |= srcScalableDims[i];
198 parallelDims.push_back(i);
199 parallelShapes.push_back(srcShape[i]);
200 parallelScalableDims.push_back(srcScalableDims[i]);
205 int flattenedParallelDim = 0;
206 int flattenedReductionDim = 0;
207 if (!parallelShapes.empty()) {
208 flattenedParallelDim = 1;
209 for (
auto d : parallelShapes)
210 flattenedParallelDim *= d;
212 if (!reductionShapes.empty()) {
213 flattenedReductionDim = 1;
214 for (
auto d : reductionShapes)
215 flattenedReductionDim *= d;
218 assert((flattenedParallelDim || flattenedReductionDim) &&
219 "expected at least one parallel or reduction dim");
224 if (useInnerDimsForReduction &&
225 llvm::any_of(parallelDims, [&](int64_t i) {
return i != counter++; }))
228 counter = reductionDims.size();
229 if (!useInnerDimsForReduction &&
230 llvm::any_of(parallelDims, [&](int64_t i) {
return i != counter++; }))
238 bool isParallelDimScalable = llvm::is_contained(parallelScalableDims,
true);
239 if (flattenedParallelDim) {
240 mask.push_back(
false);
242 scalableDims.push_back(isParallelDimScalable);
244 if (flattenedReductionDim) {
245 mask.push_back(
true);
247 scalableDims.push_back(isReductionDimScalable);
249 if (!useInnerDimsForReduction &&
vectorShape.size() == 2) {
250 std::swap(mask.front(), mask.back());
252 std::swap(scalableDims.front(), scalableDims.back());
256 if (maskableOp.isMasked()) {
257 Value vectorMask = maskableOp.getMaskingOp().getMask();
260 llvm::cast<VectorType>(vectorMask.
getType()).getElementType());
262 rewriter.
create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
266 vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
269 loc, castedType, multiReductionOp.getSource());
271 Value acc = multiReductionOp.getAcc();
272 if (flattenedParallelDim) {
274 {flattenedParallelDim},
276 {isParallelDimScalable});
277 acc = rewriter.
create<vector::ShapeCastOp>(loc, accType, acc);
281 Operation *newMultiDimRedOp = rewriter.
create<vector::MultiDimReductionOp>(
282 loc, cast, acc, mask, multiReductionOp.getKind());
288 if (parallelShapes.empty()) {
295 parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
296 parallelScalableDims);
298 rootOp, outputCastedType, newMultiDimRedOp->
getResult(0));
303 const bool useInnerDimsForReduction;
308 struct TwoDimMultiReductionToElementWise
312 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
315 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
316 if (maskableOp.isMasked())
320 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
325 if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
328 auto loc = multiReductionOp.getLoc();
330 multiReductionOp.getSourceVectorType().getShape();
336 Value result = multiReductionOp.getAcc();
337 for (int64_t i = 0; i < srcShape[0]; i++) {
338 auto operand = rewriter.
create<vector::ExtractOp>(
339 loc, multiReductionOp.getSource(), i);
344 rewriter.
replaceOp(multiReductionOp, result);
351 struct TwoDimMultiReductionToReduction
355 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
357 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
361 if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
367 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
369 if (maskableOp.isMasked()) {
371 rootOp = maskableOp.getMaskingOp();
373 rootOp = multiReductionOp;
376 auto loc = multiReductionOp.
getLoc();
378 loc, multiReductionOp.getDestType(),
379 rewriter.
getZeroAttr(multiReductionOp.getDestType()));
380 int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
382 for (
int i = 0; i < outerDim; ++i) {
383 auto v = rewriter.
create<vector::ExtractOp>(
385 auto acc = rewriter.
create<vector::ExtractOp>(
388 loc, multiReductionOp.getKind(), v, acc);
391 if (maskableOp.isMasked()) {
397 result = rewriter.
create<vector::InsertElementOp>(
398 loc, reductionOp->getResult(0), result,
399 rewriter.
create<arith::ConstantIndexOp>(loc, i));
412 struct OneDimMultiReductionToTwoDim
416 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
418 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
426 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
429 if (maskableOp.isMasked()) {
431 rootOp = maskableOp.getMaskingOp();
432 mask = maskableOp.getMaskingOp().getMask();
434 rootOp = multiReductionOp;
437 auto loc = multiReductionOp.
getLoc();
438 auto srcVectorType = multiReductionOp.getSourceVectorType();
439 auto srcShape = srcVectorType.getShape();
446 assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
447 "multi_reduction with a single dimension expects a scalar result");
455 loc, castedType, multiReductionOp.getSource());
456 Value castAcc = rewriter.
create<vector::BroadcastOp>(
457 loc, accType, multiReductionOp.getAcc());
459 if (maskableOp.isMasked()) {
460 auto maskType = llvm::cast<VectorType>(mask.
getType());
463 maskType.getElementType(),
465 castMask = rewriter.
create<vector::BroadcastOp>(loc, castMaskType, mask);
469 loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
478 struct LowerVectorMultiReductionPass
479 :
public vector::impl::LowerVectorMultiReductionBase<
480 LowerVectorMultiReductionPass> {
481 LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
482 this->loweringStrategy = option;
485 void runOnOperation()
override {
491 this->loweringStrategy);
498 registry.
insert<vector::VectorDialect>();
507 patterns.
add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
509 patterns.
add<OneDimMultiReductionToTwoDim>(patterns.
getContext(), benefit);
510 if (
options == VectorMultiReductionLowering ::InnerReduction)
511 patterns.
add<TwoDimMultiReductionToReduction>(patterns.
getContext(),
514 patterns.
add<TwoDimMultiReductionToElementWise>(patterns.
getContext(),
519 vector::VectorMultiReductionLowering option) {
520 return std::make_unique<LowerVectorMultiReductionPass>(option);
static llvm::ManagedStatic< PassManagerOptions > options
static VectorShape vectorShape(Type type)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
TypedAttr getZeroAttr(Type type)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
std::unique_ptr< Pass > createLowerVectorMultiReductionPass(VectorMultiReductionLowering option=VectorMultiReductionLowering::InnerParallel)
Creates an instance of the vector.multi_reduction lowering pass.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector....
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...