#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>

using namespace mlir;
/// Breaks down a wide vector `gpu.subgroup_reduce` into reductions over
/// pieces that fit within `maxShuffleBitwidth`.
struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
                          PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
  }

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() < 2)
      return rewriter.notifyMatchFailure(op, "not a multi-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");

    Type elemTy = vecTy.getElementType();
    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
    if (elemBitwidth >= maxShuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("element type too large ({0}), cannot break down "
                            "into vectors of bitwidth {1} or less",
                            elemBitwidth, maxShuffleBitwidth));

    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
    assert(elementsPerShuffle >= 1);

    unsigned numNewReductions =
        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
    assert(numNewReductions >= 1);
    if (numNewReductions == 1)
      return rewriter.notifyMatchFailure(op, "nothing to break down");

    Location loc = op.getLoc();
    Value res =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));

    for (unsigned i = 0; i != numNewReductions; ++i) {
      int64_t startIdx = i * elementsPerShuffle;
      int64_t endIdx =
          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
      int64_t numElems = endIdx - startIdx;

      Value extracted;
      if (numElems == 1) {
        extracted =
            rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
      } else {
        extracted = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
            /*strides=*/1);
      }

      Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
          loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
          op.getClusterStride());
      if (numElems == 1) {
        res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
        continue;
      }

      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxShuffleBitwidth = 0;
};
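// Illustrative sketch (not taken from the upstream file): assuming
// `maxShuffleBitwidth = 32`, the pattern above splits a wide vector reduction
// into 32-bit-sized pieces roughly as follows:
//
//   %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> (vector<3xf16>)
//     ==>
//   %e0 = vector.extract_strided_slice %x {offsets = [0], sizes = [2], ...}
//   %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> (vector<2xf16>)
//   %v0 = vector.insert_strided_slice %r0, %cst {offsets = [0], ...}
//   %e1 = vector.extract %x[2] : f16 from vector<3xf16>
//   %r1 = gpu.subgroup_reduce add %e1 : (f16) -> (f16)
//   %a  = vector.insert %r1, %v0 [2] : f16 into vector<3xf16>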
/// Scalarizes single-element vector reductions.
struct ScalarizeSingleElementReduce final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() != 1)
      return rewriter.notifyMatchFailure(op, "not a single-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");
    Location loc = op.getLoc();
    Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
    Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
        loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
        op.getClusterStride());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
    return success();
  }
};
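// Illustrative sketch (not taken from the upstream file): a single-element
// vector reduction is rewritten into a scalar reduction plus a broadcast,
// roughly:
//
//   %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> (vector<1xf32>)
//     ==>
//   %e0 = vector.extract %x[0] : f32 from vector<1xf32>
//   %r0 = gpu.subgroup_reduce add %e0 : (f32) -> (f32)
//   %a  = vector.broadcast %r0 : f32 to vector<1xf32>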
struct ClusterInfo {
  unsigned clusterStride;
  unsigned clusterSize;
  unsigned subgroupSize;
};

static FailureOr<ClusterInfo>
getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
  assert(llvm::isPowerOf2_32(subgroupSize));

  std::optional<uint32_t> clusterSize = op.getClusterSize();
  assert(!clusterSize ||
         llvm::isPowerOf2_32(*clusterSize)); // Verifier should've caught this.
  if (clusterSize && *clusterSize > subgroupSize)
    return op.emitOpError()
           << "cluster size " << *clusterSize
           << " is greater than subgroup size " << subgroupSize;
  unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);

  auto clusterStride = op.getClusterStride();
  assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should've caught this.
  if (clusterStride >= subgroupSize)
    return op.emitOpError()
           << "cluster stride " << clusterStride
           << " is not less than subgroup size " << subgroupSize;

  return ClusterInfo{clusterStride, effectiveClusterSize, subgroupSize};
}
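// Worked example (illustrative): for a subgroup of 32 lanes, a reduce with
// `cluster_size = 4` and `cluster_stride = 2` reduces lanes {0,2,4,6},
// {1,3,5,7}, {8,10,12,14}, ... independently; with no cluster attributes the
// effective cluster is the whole subgroup (stride 1, size equal to the
// subgroup size).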
/// Emits a subgroup reduction using a sequence of shuffles, using `packFn` and
/// `unpackFn` to convert to and from the native shuffle type.
static Value
createSubgroupShuffleReduction(OpBuilder &builder, Location loc, Value input,
                               gpu::AllReduceOperation mode,
                               const ClusterInfo &ci,
                               function_ref<Value(Value)> packFn,
                               function_ref<Value(Value)> unpackFn) {
  // The lane value always stays in the original type; arith reductions are
  // performed on it directly.
  Value laneVal = input;
  // Parallel reduction using butterfly shuffles.
  for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize;
       i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
                                                 /*width=*/ci.subgroupSize,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = vector::makeArithReduction(builder, loc,
                                         gpu::convertReductionKind(mode),
                                         laneVal, unpackFn(shuffled));
    assert(laneVal.getType() == input.getType());
  }

  return laneVal;
}
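// Worked example (illustrative): with `clusterStride = 1` and
// `clusterSize = 8`, the loop above emits XOR shuffles with offsets 1, 2, and
// 4 (log2(8) = 3 steps); after the last step every lane of the cluster holds
// the full 8-lane reduction. With `clusterStride = 2` and `clusterSize = 4`,
// the offsets are 2 and 4 instead.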
/// Lowers scalar gpu subgroup reductions to a series of shuffles.
struct ScalarSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    Type valueTy = op.getType();
    unsigned elemBitwidth =
        getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
    if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "value type is not a compatible scalar");

    Location loc = op.getLoc();
    // Since this is already a native shuffle scalar, no packing is necessary.
    if (elemBitwidth == shuffleBitwidth) {
      auto identityFn = [](Value v) { return v; };
      rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                 rewriter, loc, op.getValue(), op.getOp(), *ci,
                                 identityFn, identityFn));
      return success();
    }

    // Otherwise, pack the value into a shuffle-sized integer before each
    // shuffle and unpack it afterwards.
    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto equivIntType = rewriter.getIntegerType(elemBitwidth);
    auto packFn = [loc, &rewriter, equivIntType,
                   shuffleIntType](Value unpackedVal) -> Value {
      auto asInt =
          rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
      return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
    };
    auto unpackFn = [loc, &rewriter, equivIntType,
                     valueTy](Value packedVal) -> Value {
      auto asInt =
          rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
      return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
    };

    rewriter.replaceOp(
        op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
                                           op.getOp(), *ci, packFn, unpackFn));
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};
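// Illustrative sketch (not taken from the upstream file): for an f16
// reduction with 32-bit shuffles, `packFn`/`unpackFn` above produce roughly
// this IR per butterfly step:
//
//   %pack = arith.bitcast %lane : f16 to i16
//   %ext  = arith.extui %pack : i16 to i32
//   %shfl, %valid = gpu.shuffle xor %ext, %offset, %width : i32
//   %back = arith.trunci %shfl : i32 to i16
//   %other = arith.bitcast %back : i16 to f16
//   // followed by the arith op matching the reduction kind, e.g. arith.addf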
/// Lowers vector gpu subgroup reductions to a series of shuffles.
struct VectorSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy)
      return rewriter.notifyMatchFailure(op, "value type is not a vector");

    unsigned vecBitwidth =
        vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
    if (vecBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op,
          llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
                        "to shuffles of size {1}",
                        vecBitwidth, shuffleBitwidth));

    unsigned elementsPerShuffle =
        shuffleBitwidth / vecTy.getElementTypeBitWidth();
    if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "shuffle bitwidth is not a multiple of the element bitwidth");

    Location loc = op.getLoc();

    // If the reduced type is smaller than the native shuffle size, extend it,
    // perform the shuffles, and extract at the end.
    auto extendedVecTy = VectorType::get(
        static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
    Value extendedInput = op.getValue();
    if (vecBitwidth < shuffleBitwidth) {
      auto zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getZeroAttr(extendedVecTy));
      extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto shuffleVecType = VectorType::get(1, shuffleIntType);

    auto packFn = [loc, &rewriter,
                   shuffleVecType](Value unpackedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
      return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
    };
    auto unpackFn = [loc, &rewriter, shuffleVecType,
                     extendedVecTy](Value packedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
      return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
    };

    Value res = createSubgroupShuffleReduction(
        rewriter, loc, extendedInput, op.getOp(), *ci, packFn, unpackFn);

    if (vecBitwidth < shuffleBitwidth) {
      res = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
          /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};
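// Illustrative sketch (not taken from the upstream file): a vector<2xf16>
// reduction with 32-bit shuffles bitcasts the value to vector<1xi32> and
// shuffles the single i32 lane:
//
//   %pack  = vector.bitcast %lane : vector<2xf16> to vector<1xi32>
//   %int   = vector.extract %pack[0] : i32 from vector<1xi32>
//   %shfl, %valid = gpu.shuffle xor %int, %offset, %width : i32
//   %vec   = vector.broadcast %shfl : i32 to vector<1xi32>
//   %other = vector.bitcast %vec : vector<1xi32> to vector<2xf16>
//
// A narrower input (e.g. vector<1xf16>) is first zero-extended to the
// shuffle-sized vector and the result is extracted back at the end.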
static FailureOr<Value>
createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
                           Value input, gpu::AllReduceOperation mode,
                           const ClusterInfo &ci, amdgpu::Chipset chipset) {
  Location loc = op.getLoc();
  Value dpp;
  Value res = input;
  constexpr int allRows = 0xf;
  constexpr int allBanks = 0xf;
  const bool boundCtrl = true;

  if (ci.clusterSize >= 2) {
    // Perform reduction between all lanes N <-> N+1.
    dpp = rewriter.create<amdgpu::DPPOp>(
        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
        rewriter.getI32ArrayAttr({1, 0, 3, 2}), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(
        rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
  }

  if (ci.clusterSize >= 4) {
    // Perform reduction between all lanes N <-> N+2.
    dpp = rewriter.create<amdgpu::DPPOp>(
        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
        rewriter.getI32ArrayAttr({2, 3, 0, 1}), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(
        rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
  }

  if (ci.clusterSize >= 8) {
    // Perform reduction between all lanes N <-> 7-N,
    // e.g. lane[0] <-> lane[7], lane[1] <-> lane[6], ..., lane[3] <-> lane[4].
    dpp = rewriter.create<amdgpu::DPPOp>(
        loc, res.getType(), res, res, amdgpu::DPPPerm::row_half_mirror,
        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(
        rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
  }

  if (ci.clusterSize >= 16) {
    // Perform reduction between all lanes N <-> 15-N,
    // e.g. lane[0] <-> lane[15], lane[1] <-> lane[14], ..., lane[7] <-> lane[8].
    dpp = rewriter.create<amdgpu::DPPOp>(
        loc, res.getType(), res, res, amdgpu::DPPPerm::row_mirror,
        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(
        rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
  }

  if (ci.clusterSize >= 32) {
    if (chipset.majorVersion <= 9) {
      // Broadcast the last lane of each row to the next row; use a row mask
      // to avoid polluting rows 1 and 3.
      dpp = rewriter.create<amdgpu::DPPOp>(
          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_15,
          rewriter.getUnitAttr(), /*row_mask=*/0xa, allBanks,
          /*bound_ctrl=*/false);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
    } else if (chipset.majorVersion <= 12) {
      // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2).
      Value uint32Max = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(-1));
      dpp = rewriter.create<ROCDL::PermlaneX16Op>(loc, res.getType(), res, res,
                                                  uint32Max, uint32Max,
                                                  /*fi=*/true,
                                                  /*boundControl=*/false);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
      if (ci.subgroupSize == 32) {
        Value lane0 = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
        res =
            rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
      }
    } else {
      return rewriter.notifyMatchFailure(
          op, "Subgroup reduce lowering to DPP not currently supported for "
              "this device.");
    }
  }

  if (ci.clusterSize >= 64) {
    if (chipset.majorVersion <= 9) {
      // Broadcast lane 31 to rows 2 and 3; use a row mask to avoid polluting
      // rows 0 and 1.
      dpp = rewriter.create<amdgpu::DPPOp>(
          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_31,
          rewriter.getUnitAttr(), /*row_mask=*/0xc, allBanks,
          /*bound_ctrl=*/false);
    } else if (chipset.majorVersion <= 12) {
      // The reduction across the first 32 lanes is already done; combine the
      // values held by lane 0 and lane 32.
      Value lane0 = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
      Value lane32 = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(32));
      dpp = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane32);
      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
    } else {
      return rewriter.notifyMatchFailure(
          op, "Subgroup reduce lowering to DPP not currently supported for "
              "this device.");
    }
    res = vector::makeArithReduction(
        rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
  }

  assert(res.getType() == input.getType());
  return res;
}
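// Summary (illustrative): for a contiguous 64-lane cluster on gfx9, the
// function above chains the DPP permutes quad_perm(1,0,3,2),
// quad_perm(2,3,0,1), row_half_mirror, row_mirror, row_bcast_15, and
// row_bcast_31, folding the shuffled value into the running reduction after
// each step; on gfx10-12, where cross-row broadcasts are unavailable, the
// cross-row steps are replaced by permlanex16 and readlane.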
/// Lowers scalar gpu subgroup reductions to a series of `amdgpu.dpp` ops.
/// Applicable only to AMD GPUs.
struct ScalarSubgroupReduceToDPP final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToDPP(MLIRContext *ctx, unsigned subgroupSize,
                            bool matchClustered, amdgpu::Chipset chipset,
                            PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        matchClustered(matchClustered), chipset(chipset) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    if (ci->clusterStride != 1)
      return rewriter.notifyMatchFailure(
          op, "Subgroup reductions using DPP are currently only available for "
              "clusters of contiguous lanes.");

    Type valueTy = op.getType();
    if (!valueTy.isIntOrFloat())
      return rewriter.notifyMatchFailure(
          op, "Value type is not a compatible scalar.");

    FailureOr<Value> dpp = createSubgroupDPPReduction(
        rewriter, op, op.getValue(), op.getOp(), *ci, chipset);
    if (failed(dpp))
      return failure();

    rewriter.replaceOp(op, dpp.value());
    return success();
  }

private:
  unsigned subgroupSize = 0;
  bool matchClustered = false;
  amdgpu::Chipset chipset;
};
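// Illustrative sketch (not taken from the upstream file): with this pattern,
// a clustered scalar reduction such as
//
//   %sum = gpu.subgroup_reduce add %v cluster(size = 4) : (f32) -> (f32)
//
// lowers to two amdgpu.dpp quad_perm permutes plus two arith.addf ops, with
// no cross-lane memory traffic.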
void mlir::populateGpuBreakDownSubgroupReducePatterns(
    RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
    PatternBenefit benefit) {
  patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
                                        maxShuffleBitwidth, benefit);
  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}

void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/false, benefit);
}

void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/true, benefit);
}
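// Usage sketch (illustrative, not part of this file): a lowering pipeline
// could combine the break-down and shuffle patterns roughly as follows,
// assuming a subgroup size of 32 and 32-bit wide shuffles:
//
//   RewritePatternSet patterns(ctx);
//   populateGpuBreakDownSubgroupReducePatterns(patterns,
//                                              /*maxShuffleBitwidth=*/32);
//   populateGpuLowerSubgroupReduceToShufflePatterns(
//       patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();
//
// The exact driver name (applyPatternsGreedily vs. applyPatternsAndFoldGreedily)
// depends on the MLIR revision; the populate functions are declared alongside
// the other GPU transform entry points.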