#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"

using namespace mlir;
/// Breaks down a vector subgroup reduction into a series of smaller reductions
/// whose operands fit into at most `maxShuffleBitwidth` bits each.
struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
                          PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit),
        maxShuffleBitwidth(maxShuffleBitwidth) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() < 2)
      return rewriter.notifyMatchFailure(op, "not a multi-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");

    Type elemTy = vecTy.getElementType();
    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
    if (elemBitwidth >= maxShuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("element type too large ({0}), cannot break down "
                            "into vectors of bitwidth {1} or less",
                            elemBitwidth, maxShuffleBitwidth));

    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
    assert(elementsPerShuffle >= 1);

    unsigned numNewReductions =
        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
    assert(numNewReductions >= 1);
    if (numNewReductions == 1)
      return rewriter.notifyMatchFailure(op, "nothing to break down");

    Location loc = op.getLoc();
    Value res =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));

    for (unsigned i = 0; i != numNewReductions; ++i) {
      int64_t startIdx = i * elementsPerShuffle;
      int64_t endIdx =
          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
      int64_t numElems = endIdx - startIdx;

      Value extracted;
      if (numElems == 1) {
        extracted =
            rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
      } else {
        extracted = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
            /*strides=*/1);
      }

      Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
          loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
          op.getClusterStride());
      if (numElems == 1) {
        res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
        continue;
      }

      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxShuffleBitwidth = 0;
};
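// Illustrative sketch (not verified compiler output), assuming
// `maxShuffleBitwidth` = 32: a reduction over vector<3xf16> is split into a
// vector<2xf16> piece plus a trailing scalar piece, roughly:
//
//   %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
//  ==>
//   %v0 = arith.constant dense<0.0> : vector<3xf16>
//   %e0 = vector.extract_strided_slice %x {offsets = [0], sizes = [2], strides = [1]}
//   %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
//   %v1 = vector.insert_strided_slice %r0, %v0 {offsets = [0], strides = [1]}
//   %e1 = vector.extract %x[2]
//   %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
//   %a  = vector.insert %r1, %v1[2]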
/// Scalarizes single-element vector reductions.
struct ScalarizeSingleElementReduce final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() != 1)
      return rewriter.notifyMatchFailure(op, "not a single-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");
    Location loc = op.getLoc();
    Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
    Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
        loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
        op.getClusterStride());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
    return success();
  }
};
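// Illustrative sketch: a single-element vector reduction becomes a scalar
// reduction followed by a broadcast, roughly:
//
//   %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32>
//  ==>
//   %e0 = vector.extract %x[0]
//   %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32
//   %a  = vector.broadcast %r0 : f32 to vector<1xf32>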
struct ClusterInfo {
  unsigned clusterStride;
  unsigned clusterSize;
  unsigned subgroupSize;
};

static FailureOr<ClusterInfo>
getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
  assert(llvm::isPowerOf2_32(subgroupSize));

  std::optional<uint32_t> clusterSize = op.getClusterSize();
  assert(!clusterSize ||
         llvm::isPowerOf2_32(*clusterSize)); // Verifier should have caught this.
  if (clusterSize && *clusterSize > subgroupSize)
    return op.emitOpError()
           << "cluster size " << *clusterSize
           << " is greater than subgroup size " << subgroupSize;
  unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);

  auto clusterStride = op.getClusterStride();
  assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should have caught this.
  if (clusterStride >= subgroupSize)
    return op.emitOpError()
           << "cluster stride " << clusterStride
           << " is not less than subgroup size " << subgroupSize;

  return ClusterInfo{clusterStride, effectiveClusterSize, subgroupSize};
}
/// Emits a cluster-wide reduction as a sequence of xor (butterfly) shuffles;
/// `packFn`/`unpackFn` convert to and from the native shuffle type.
static Value createSubgroupShuffleReduction(
    OpBuilder &builder, Location loc, Value input,
    gpu::AllReduceOperation mode, const ClusterInfo &ci,
    function_ref<Value(Value)> packFn, function_ref<Value(Value)> unpackFn) {
  Value laneVal = input;
  for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize;
       i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
                                                 /*width=*/ci.subgroupSize,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = vector::makeArithReduction(builder, loc,
                                         gpu::convertReductionKind(mode),
                                         laneVal, unpackFn(shuffled));
  }
  return laneVal;
}
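// Worked example (a sketch): with ci.clusterStride = 1 and ci.clusterSize = 8,
// the loop above issues xor shuffles with offsets 1, 2, and 4, so every lane
// accumulates the values of all 8 lanes in its cluster in log2(8) = 3 steps.
// With ci.clusterStride = 2 the offsets become 2, 4, and 8, and each lane
// instead reduces over the 8 lanes spaced two apart within a 16-lane window.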
/// Lowers scalar gpu subgroup reductions to a series of shuffles.
struct ScalarSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    Type valueTy = op.getType();
    unsigned elemBitwidth =
        getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
    if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "value type is not a compatible scalar");

    Location loc = op.getLoc();
    // A value that already matches the native shuffle width needs no packing.
    if (elemBitwidth == shuffleBitwidth) {
      auto identityFn = [](Value v) { return v; };
      rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                 rewriter, loc, op.getValue(), op.getOp(), *ci,
                                 identityFn, identityFn));
      return success();
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto equivIntType = rewriter.getIntegerType(elemBitwidth);
    auto packFn = [loc, &rewriter, equivIntType,
                   shuffleIntType](Value unpackedVal) -> Value {
      auto asInt =
          rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
      return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
    };
    auto unpackFn = [loc, &rewriter, equivIntType,
                     valueTy](Value packedVal) -> Value {
      auto asInt =
          rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
      return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
    };

    rewriter.replaceOp(
        op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
                                           op.getOp(), *ci, packFn, unpackFn));
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};
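// Illustrative sketch, assuming subgroupSize = 32 and shuffleBitwidth = 32: an
// f16 add reduction takes the pack/unpack path above, and each butterfly step
// looks roughly like:
//
//   %i16  = arith.bitcast %x : f16 to i16
//   %i32  = arith.extui %i16 : i16 to i32
//   %shfl, %ok = gpu.shuffle xor %i32, %offset, %width : i32
//   %t16  = arith.trunci %shfl : i32 to i16
//   %f16  = arith.bitcast %t16 : i16 to f16
//   %acc  = arith.addf %x, %f16 : f16
//
// with %offset taking the values 1, 2, 4, 8, and 16 across the five steps.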
/// Lowers vector gpu subgroup reductions to a series of shuffles.
struct VectorSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy)
      return rewriter.notifyMatchFailure(op, "value type is not a vector");

    unsigned vecBitwidth =
        vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
    if (vecBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op,
          llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
                        "to shuffles of size {1}",
                        vecBitwidth, shuffleBitwidth));

    unsigned elementsPerShuffle =
        shuffleBitwidth / vecTy.getElementTypeBitWidth();
    if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "shuffle bitwidth is not a multiple of the element bitwidth");

    Location loc = op.getLoc();

    // If the reduced type is smaller than the native shuffle size, extend it,
    // perform the shuffles, and extract at the end.
    auto extendedVecTy = VectorType::get(
        static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
    Value extendedInput = op.getValue();
    if (vecBitwidth < shuffleBitwidth) {
      auto zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getZeroAttr(extendedVecTy));
      extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto shuffleVecType = VectorType::get(1, shuffleIntType);

    auto packFn = [loc, &rewriter,
                   shuffleVecType](Value unpackedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
      return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
    };
    auto unpackFn = [loc, &rewriter, shuffleVecType,
                     extendedVecTy](Value packedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
      return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
    };

    Value res = createSubgroupShuffleReduction(
        rewriter, loc, extendedInput, op.getOp(), *ci, packFn, unpackFn);

    if (vecBitwidth < shuffleBitwidth) {
      res = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
          /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};
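// Illustrative sketch, assuming shuffleBitwidth = 32: a vector<2xf16> operand
// already fills one shuffle word, so each step bitcasts it to vector<1xi32>,
// extracts the i32 for gpu.shuffle, then broadcasts and bitcasts the shuffled
// word back to vector<2xf16> for the arith reduction. A vector<1xf16> operand
// would first be widened to vector<2xf16> with vector.insert_strided_slice and
// narrowed back with vector.extract_strided_slice at the end.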
void mlir::populateGpuBreakDownSubgroupReducePatterns(
    RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
    PatternBenefit benefit) {
  patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
                                        maxShuffleBitwidth, benefit);
  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}

void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/false, benefit);
}

void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/true, benefit);
}
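// Minimal usage sketch (not part of this file; the pass context, `op`, and the
// greedy-driver call are illustrative assumptions):
//
//   RewritePatternSet patterns(ctx);
//   populateGpuBreakDownSubgroupReducePatterns(patterns,
//                                              /*maxShuffleBitwidth=*/32);
//   populateGpuLowerSubgroupReduceToShufflePatterns(
//       patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     signalPassFailure();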