19 #define DEBUG_TYPE "vector-multi-reduction"
29 class InnerOuterDimReductionConversion
34 explicit InnerOuterDimReductionConversion(
38 useInnerDimsForReduction(
39 options == vector::VectorMultiReductionLowering::InnerReduction) {}
41 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
46 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
48 if (maskableOp.isMasked()) {
50 rootOp = maskableOp.getMaskingOp();
52 rootOp = multiReductionOp;
55 auto src = multiReductionOp.getSource();
56 auto loc = multiReductionOp.
getLoc();
57 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
60 auto reductionDimsRange =
61 multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
62 auto reductionDims = llvm::to_vector<4>(llvm::map_range(
63 reductionDimsRange, [](
const APInt &a) {
return a.getZExtValue(); }));
64 llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
66 int64_t reductionSize = reductionDims.size();
68 for (int64_t i = 0; i < srcRank; ++i)
69 if (!reductionDimsSet.contains(i))
70 parallelDims.push_back(i);
74 if (parallelDims.empty())
76 if (useInnerDimsForReduction &&
78 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
81 if (!useInnerDimsForReduction &&
82 (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
84 parallelDims.size() + reductionDims.size()))))
88 if (useInnerDimsForReduction) {
89 indices.append(parallelDims.begin(), parallelDims.end());
90 indices.append(reductionDims.begin(), reductionDims.end());
92 indices.append(reductionDims.begin(), reductionDims.end());
93 indices.append(parallelDims.begin(), parallelDims.end());
98 if (maskableOp.isMasked()) {
99 transposedMask = rewriter.
create<vector::TransposeOp>(
100 loc, maskableOp.getMaskingOp().getMask(), indices);
104 auto transposeOp = rewriter.
create<vector::TransposeOp>(loc, src, indices);
106 for (
int i = 0; i < reductionSize; ++i) {
107 if (useInnerDimsForReduction)
108 reductionMask[srcRank - i - 1] =
true;
110 reductionMask[i] =
true;
113 Operation *newMultiRedOp = rewriter.
create<vector::MultiDimReductionOp>(
114 multiReductionOp.getLoc(), transposeOp.getResult(),
115 multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
124 const bool useInnerDimsForReduction;
129 class ReduceMultiDimReductionRank
134 explicit ReduceMultiDimReductionRank(
138 useInnerDimsForReduction(
139 options == vector::VectorMultiReductionLowering::InnerReduction) {}
141 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
146 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
148 if (maskableOp.isMasked()) {
150 rootOp = maskableOp.getMaskingOp();
152 rootOp = multiReductionOp;
155 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
156 auto srcShape = multiReductionOp.getSourceVectorType().getShape();
157 auto srcScalableDims =
158 multiReductionOp.getSourceVectorType().getScalableDims();
159 auto loc = multiReductionOp.
getLoc();
167 if (llvm::count(srcScalableDims,
true) > 1)
172 if (srcRank == 2 && reductionMask.front() != reductionMask.back())
179 bool isReductionDimScalable =
false;
181 int64_t i = it.index();
182 bool isReduction = it.value();
184 reductionDims.push_back(i);
185 reductionShapes.push_back(srcShape[i]);
186 isReductionDimScalable |= srcScalableDims[i];
188 parallelDims.push_back(i);
189 parallelShapes.push_back(srcShape[i]);
190 parallelScalableDims.push_back(srcScalableDims[i]);
195 int flattenedParallelDim = 0;
196 int flattenedReductionDim = 0;
197 if (!parallelShapes.empty()) {
198 flattenedParallelDim = 1;
199 for (
auto d : parallelShapes)
200 flattenedParallelDim *= d;
202 if (!reductionShapes.empty()) {
203 flattenedReductionDim = 1;
204 for (
auto d : reductionShapes)
205 flattenedReductionDim *= d;
208 assert((flattenedParallelDim || flattenedReductionDim) &&
209 "expected at least one parallel or reduction dim");
214 if (useInnerDimsForReduction &&
215 llvm::any_of(parallelDims, [&](int64_t i) {
return i != counter++; }))
218 counter = reductionDims.size();
219 if (!useInnerDimsForReduction &&
220 llvm::any_of(parallelDims, [&](int64_t i) {
return i != counter++; }))
228 bool isParallelDimScalable = llvm::is_contained(parallelScalableDims,
true);
229 if (flattenedParallelDim) {
230 mask.push_back(
false);
232 scalableDims.push_back(isParallelDimScalable);
234 if (flattenedReductionDim) {
235 mask.push_back(
true);
237 scalableDims.push_back(isReductionDimScalable);
239 if (!useInnerDimsForReduction &&
vectorShape.size() == 2) {
240 std::swap(mask.front(), mask.back());
242 std::swap(scalableDims.front(), scalableDims.back());
246 if (maskableOp.isMasked()) {
247 Value vectorMask = maskableOp.getMaskingOp().getMask();
250 llvm::cast<VectorType>(vectorMask.
getType()).getElementType());
252 rewriter.
create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
256 vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
259 loc, castedType, multiReductionOp.getSource());
261 Value acc = multiReductionOp.getAcc();
262 if (flattenedParallelDim) {
264 {flattenedParallelDim},
266 {isParallelDimScalable});
267 acc = rewriter.
create<vector::ShapeCastOp>(loc, accType, acc);
271 Operation *newMultiDimRedOp = rewriter.
create<vector::MultiDimReductionOp>(
272 loc, cast, acc, mask, multiReductionOp.getKind());
278 if (parallelShapes.empty()) {
285 parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
286 parallelScalableDims);
288 rootOp, outputCastedType, newMultiDimRedOp->
getResult(0));
293 const bool useInnerDimsForReduction;
298 struct TwoDimMultiReductionToElementWise
302 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
305 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
306 if (maskableOp.isMasked())
310 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
315 if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
318 auto loc = multiReductionOp.getLoc();
320 multiReductionOp.getSourceVectorType().getShape();
326 Value result = multiReductionOp.getAcc();
327 for (int64_t i = 0; i < srcShape[0]; i++) {
328 auto operand = rewriter.
create<vector::ExtractOp>(
329 loc, multiReductionOp.getSource(), i);
334 rewriter.
replaceOp(multiReductionOp, result);
341 struct TwoDimMultiReductionToReduction
345 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
347 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
351 if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
357 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
359 if (maskableOp.isMasked()) {
361 rootOp = maskableOp.getMaskingOp();
363 rootOp = multiReductionOp;
366 auto loc = multiReductionOp.
getLoc();
368 loc, multiReductionOp.getDestType(),
369 rewriter.
getZeroAttr(multiReductionOp.getDestType()));
370 int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
372 for (
int i = 0; i < outerDim; ++i) {
373 auto v = rewriter.
create<vector::ExtractOp>(
375 auto acc = rewriter.
create<vector::ExtractOp>(
378 loc, multiReductionOp.getKind(), v, acc);
381 if (maskableOp.isMasked()) {
387 result = rewriter.
create<vector::InsertElementOp>(
388 loc, reductionOp->getResult(0), result,
389 rewriter.
create<arith::ConstantIndexOp>(loc, i));
402 struct OneDimMultiReductionToTwoDim
406 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
408 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
416 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
419 if (maskableOp.isMasked()) {
421 rootOp = maskableOp.getMaskingOp();
422 mask = maskableOp.getMaskingOp().getMask();
424 rootOp = multiReductionOp;
427 auto loc = multiReductionOp.
getLoc();
428 auto srcVectorType = multiReductionOp.getSourceVectorType();
429 auto srcShape = srcVectorType.getShape();
431 srcVectorType.getElementType());
434 assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
435 "multi_reduction with a single dimension expects a scalar result");
443 loc, castedType, multiReductionOp.getSource());
444 Value castAcc = rewriter.
create<vector::BroadcastOp>(
445 loc, accType, multiReductionOp.getAcc());
447 if (maskableOp.isMasked()) {
448 auto maskType = llvm::cast<ShapedType>(mask.
getType());
451 maskType.getElementType());
452 castMask = rewriter.
create<vector::BroadcastOp>(loc, castMaskType, mask);
456 loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
469 patterns.
add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
471 patterns.
add<OneDimMultiReductionToTwoDim>(patterns.
getContext(), benefit);
472 if (
options == VectorMultiReductionLowering ::InnerReduction)
473 patterns.
add<TwoDimMultiReductionToReduction>(patterns.
getContext(),
476 patterns.
add<TwoDimMultiReductionToElementWise>(patterns.
getContext(),
static llvm::ManagedStatic< PassManagerOptions > options
static ArrayRef< int64_t > vectorShape(Type type)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
TypedAttr getZeroAttr(Type type)
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit=1)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector....
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, Value mask=Value())
Return the result value of reducing two scalar/vector values with the corresponding arith operation.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...