20 #define DEBUG_TYPE "vector-multi-reduction" 37 useInnerDimsForReduction(
42 auto src = multiReductionOp.getSource();
43 auto loc = multiReductionOp.getLoc();
44 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
47 auto reductionDimsRange =
48 multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
49 auto reductionDims = llvm::to_vector<4>(llvm::map_range(
50 reductionDimsRange, [](
const APInt &a) {
return a.getZExtValue(); }));
51 llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
53 int64_t reductionSize = reductionDims.size();
55 for (int64_t i = 0; i < srcRank; ++i)
56 if (!reductionDimsSet.contains(i))
57 parallelDims.push_back(i);
61 if (parallelDims.empty())
63 if (useInnerDimsForReduction &&
65 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
68 if (!useInnerDimsForReduction &&
70 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
74 if (useInnerDimsForReduction) {
75 indices.append(parallelDims.begin(), parallelDims.end());
76 indices.append(reductionDims.begin(), reductionDims.end());
78 indices.append(reductionDims.begin(), reductionDims.end());
79 indices.append(parallelDims.begin(), parallelDims.end());
81 auto transposeOp = rewriter.
create<vector::TransposeOp>(loc, src, indices);
83 for (
int i = 0; i < reductionSize; ++i) {
84 if (useInnerDimsForReduction)
85 reductionMask[srcRank - i - 1] =
true;
87 reductionMask[i] =
true;
90 multiReductionOp, transposeOp.getResult(), multiReductionOp.getAcc(),
91 reductionMask, multiReductionOp.getKind());
96 const bool useInnerDimsForReduction;
109 useInnerDimsForReduction(
114 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
115 auto srcShape = multiReductionOp.getSourceVectorType().getShape();
116 auto loc = multiReductionOp.getLoc();
124 if (srcRank == 2 && reductionMask.front() != reductionMask.back())
131 int64_t i = it.index();
132 bool isReduction = it.value();
134 reductionDims.push_back(i);
135 reductionShapes.push_back(srcShape[i]);
137 parallelDims.push_back(i);
138 parallelShapes.push_back(srcShape[i]);
143 int flattenedParallelDim = 0;
144 int flattenedReductionDim = 0;
145 if (!parallelShapes.empty()) {
146 flattenedParallelDim = 1;
147 for (
auto d : parallelShapes)
148 flattenedParallelDim *= d;
150 if (!reductionShapes.empty()) {
151 flattenedReductionDim = 1;
152 for (
auto d : reductionShapes)
153 flattenedReductionDim *= d;
156 assert((flattenedParallelDim || flattenedReductionDim) &&
157 "expected at least one parallel or reduction dim");
162 if (useInnerDimsForReduction &&
163 llvm::any_of(parallelDims, [&](int64_t i) {
return i != counter++; }))
166 counter = reductionDims.size();
167 if (!useInnerDimsForReduction &&
168 llvm::any_of(parallelDims, [&](int64_t i) {
return i != counter++; }))
175 if (flattenedParallelDim) {
176 mask.push_back(
false);
177 vectorShape.push_back(flattenedParallelDim);
179 if (flattenedReductionDim) {
180 mask.push_back(
true);
181 vectorShape.push_back(flattenedReductionDim);
183 if (!useInnerDimsForReduction && vectorShape.size() == 2) {
184 std::swap(mask.front(), mask.back());
185 std::swap(vectorShape.front(), vectorShape.back());
187 auto castedType = VectorType::get(
188 vectorShape, multiReductionOp.getSourceVectorType().getElementType());
190 loc, castedType, multiReductionOp.getSource());
191 Value acc = multiReductionOp.getAcc();
192 if (flattenedParallelDim) {
193 auto accType = VectorType::get(
194 {flattenedParallelDim},
195 multiReductionOp.getSourceVectorType().getElementType());
196 acc = rewriter.
create<vector::ShapeCastOp>(loc, accType, acc);
200 auto newOp = rewriter.
create<vector::MultiDimReductionOp>(
201 loc, cast, acc, mask, multiReductionOp.getKind());
205 if (parallelShapes.empty()) {
206 rewriter.
replaceOp(multiReductionOp, newOp.getDest());
211 VectorType outputCastedType = VectorType::get(
213 multiReductionOp.getSourceVectorType().getElementType());
215 multiReductionOp, outputCastedType, newOp.getDest());
220 const bool useInnerDimsForReduction;
231 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
236 if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
239 auto loc = multiReductionOp.getLoc();
241 multiReductionOp.getSourceVectorType().getShape();
247 Value result = multiReductionOp.getAcc();
248 for (int64_t i = 0; i < srcShape[0]; i++) {
249 auto operand = rewriter.
create<vector::ExtractOp>(
250 loc, multiReductionOp.getSource(), i);
255 rewriter.
replaceOp(multiReductionOp, result);
268 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
272 if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
275 auto loc = multiReductionOp.getLoc();
277 loc, multiReductionOp.getDestType(),
278 rewriter.
getZeroAttr(multiReductionOp.getDestType()));
279 int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
281 for (
int i = 0; i < outerDim; ++i) {
282 auto v = rewriter.
create<vector::ExtractOp>(
284 auto acc = rewriter.
create<vector::ExtractOp>(
286 auto reducedValue = rewriter.
create<vector::ReductionOp>(
287 loc, multiReductionOp.getKind(), v, acc);
288 result = rewriter.
create<vector::InsertElementOp>(
289 loc, reducedValue, result,
292 rewriter.
replaceOp(multiReductionOp, result);
308 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
313 auto loc = multiReductionOp.getLoc();
314 auto srcVectorType = multiReductionOp.getSourceVectorType();
315 auto srcShape = srcVectorType.getShape();
317 srcVectorType.getElementType());
320 assert(!multiReductionOp.getDestType().isa<VectorType>() &&
321 "multi_reduction with a single dimension expects a scalar result");
329 loc, castedType, multiReductionOp.getSource());
330 Value castAcc = rewriter.
create<vector::BroadcastOp>(
331 loc, accType, multiReductionOp.getAcc());
332 Value reduced = rewriter.
create<vector::MultiDimReductionOp>(
333 loc, cast, castAcc, mask, multiReductionOp.getKind());
345 if (options == VectorMultiReductionLowering ::InnerReduction)
Include the generated interface declarations.
Reduces the rank of vector.multi_reduction nd -> 2d given all reduction dimensions are either inner m...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Attribute getZeroAttr(Type type)
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
ReduceMultiDimReductionRank(MLIRContext *context, vector::VectorMultiReductionLowering options)
static ArrayRef< int64_t > vectorShape(Type type)
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns, VectorMultiReductionLowering options)
Collect a set of patterns to convert vector.multi_reduction op into a sequence of vector...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
VectorMultiReductionLowering
Enum to control the lowering of vector.multi_reduction operations.
This class represents an efficient way to signal success or failure.
Converts 1d vector.multi_reduction with a single reduction dimension to a 2d form with both a single ...
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
Unrolls vector.multi_reduction with outermost reductions and combines results.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
This file implements the following transformations as composable atomic patterns. ...
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value v2)
Return the result value of reducing two scalar/vector values with the corresponding arith operation...
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override
static llvm::ManagedStatic< PassManagerOptions > options
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Specialization of arith.constant op that returns an integer of index type.
InnerOuterDimReductionConversion(MLIRContext *context, vector::VectorMultiReductionLowering options)
MLIRContext is the top-level object for a collection of MLIR operations.
Converts 2d vector.multi_reduction with inner most reduction dimension into a sequence of vector...
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override
MLIRContext * getContext() const