#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>

using namespace mlir;

namespace {
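/// Breaks down vector reductions into one or more smaller reductions that fit
/// within `maxShuffleBitwidth`, plus leftover scalar reductions. Illustrative
/// sketch of the rewrite, assuming `maxShuffleBitwidth` equal to 32:
///
///   %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
///  ==>
///   %v0 = arith.constant dense<0.0> : vector<3xf16>
///   %e0 = vector.extract_strided_slice %x {offsets = [0], sizes = [2],
///           strides = [1]} : vector<3xf16> to vector<2xf16>
///   %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
///   %v1 = vector.insert_strided_slice %r0, %v0 {offsets = [0], strides = [1]}
///           : vector<2xf16> into vector<3xf16>
///   %e1 = vector.extract %x[2] : f16 from vector<3xf16>
///   %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
///   %a  = vector.insert %r1, %v1[2] : f16 into vector<3xf16>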
struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
                          PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
  }

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() < 2)
      return rewriter.notifyMatchFailure(op, "not a multi-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");
    Type elemTy = vecTy.getElementType();
    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
    if (elemBitwidth >= maxShuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("element type too large ({0}), cannot break down "
                            "into vectors of bitwidth {1} or less",
                            elemBitwidth, maxShuffleBitwidth));
    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
    assert(elementsPerShuffle >= 1);

    unsigned numNewReductions =
        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
    assert(numNewReductions >= 1);
    if (numNewReductions == 1)
      return rewriter.notifyMatchFailure(op, "nothing to break down");
    Location loc = op.getLoc();
    Value res =
        arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(vecTy));
    for (unsigned i = 0; i != numNewReductions; ++i) {
      int64_t startIdx = i * elementsPerShuffle;
      int64_t endIdx =
          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
      int64_t numElems = endIdx - startIdx;

      // Extract the chunk to reduce: a scalar for single leftover elements,
      // a strided slice otherwise.
      Value extracted;
      if (numElems == 1) {
        extracted =
            vector::ExtractOp::create(rewriter, loc, op.getValue(), startIdx);
      } else {
        extracted = vector::ExtractStridedSliceOp::create(
            rewriter, loc, op.getValue(), /*offsets=*/startIdx,
            /*sizes=*/numElems, /*strides=*/1);
      }

      Value reduce = gpu::SubgroupReduceOp::create(
          rewriter, loc, extracted, op.getOp(), op.getUniform(),
          op.getClusterSize(), op.getClusterStride());
      if (numElems == 1) {
        res = vector::InsertOp::create(rewriter, loc, reduce, res, startIdx);
        continue;
      }

      res = vector::InsertStridedSliceOp::create(
          rewriter, loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxShuffleBitwidth = 0;
};
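/// Scalarizes single-element vector reductions. Illustrative sketch of the
/// rewrite:
///
///   %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32>
///  ==>
///   %e0 = vector.extract %x[0] : f32 from vector<1xf32>
///   %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32
///   %a  = vector.broadcast %r0 : f32 to vector<1xf32>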
struct ScalarizeSingleElementReduce final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() != 1)
      return rewriter.notifyMatchFailure(op, "not a single-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");
    Location loc = op.getLoc();
    Value extracted =
        vector::ExtractOp::create(rewriter, loc, op.getValue(), 0);
    Value reduce = gpu::SubgroupReduceOp::create(
        rewriter, loc, extracted, op.getOp(), op.getUniform(),
        op.getClusterSize(), op.getClusterStride());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
    return success();
  }
};
struct ClusterInfo {
  unsigned clusterStride;
  unsigned clusterSize;
  unsigned subgroupSize;
};
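// Illustrative note: a reduction with cluster stride `s` and cluster size `n`
// combines, in parallel, every set of `n` lanes spaced `s` apart. For example,
// clusterStride = 2 and clusterSize = 4 reduce over lanes {0, 2, 4, 6},
// {1, 3, 5, 7}, and so on across the subgroup.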
static FailureOr<ClusterInfo>
getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
  assert(llvm::isPowerOf2_32(subgroupSize));

  std::optional<uint32_t> clusterSize = op.getClusterSize();
  assert(!clusterSize ||
         llvm::isPowerOf2_32(*clusterSize)); // Verifier should've caught this.
  if (clusterSize && *clusterSize > subgroupSize)
    return op.emitOpError()
           << "cluster size " << *clusterSize
           << " is greater than subgroup size " << subgroupSize;
  unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
  auto clusterStride = op.getClusterStride();
  assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should've caught this.
  if (clusterStride >= subgroupSize)
    return op.emitOpError()
           << "cluster stride " << clusterStride
           << " is not less than subgroup size " << subgroupSize;

  return ClusterInfo{clusterStride, effectiveClusterSize, subgroupSize};
}
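/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
/// and `unpackFn` callbacks to convert to the native shuffle type and back to
/// the reduction type, respectively. For example, with an `f16` input,
/// `packFn` could build ops that bitcast and extend the value to `i32` for
/// shuffling, while `unpackFn` would truncate and bitcast it back to `f16`
/// for the arithmetic reduction.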
Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
                                     Value input, gpu::AllReduceOperation mode,
                                     const ClusterInfo &ci,
                                     function_ref<Value(Value)> packFn,
                                     function_ref<Value(Value)> unpackFn) {
  // The lane value always stays in the original type. We use it to perform
  // arithmetic reductions.
  Value laneVal = input;
  // Parallel reduction using butterfly (XOR) shuffles.
  for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize;
       i <<= 1) {
    Value shuffled = gpu::ShuffleOp::create(builder, loc, packFn(laneVal), i,
                                            /*width=*/ci.subgroupSize,
                                            /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = vector::makeArithReduction(builder, loc,
                                         gpu::convertReductionKind(mode),
                                         laneVal, unpackFn(shuffled));
    assert(laneVal.getType() == input.getType());
  }

  return laneVal;
}
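/// Lowers scalar `gpu.subgroup_reduce` ops to a series of `gpu.shuffle` ops,
/// packing element types narrower than `shuffleBitwidth` into
/// shuffle-compatible integers.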
struct ScalarSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();
    Type valueTy = op.getType();
    unsigned elemBitwidth =
        getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
    if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "value type is not a compatible scalar");
    Location loc = op.getLoc();
    // Since this is already a native shuffle scalar, no packing is necessary.
    if (elemBitwidth == shuffleBitwidth) {
      auto identityFn = [](Value v) { return v; };
      rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                 rewriter, loc, op.getValue(), op.getOp(), *ci,
                                 identityFn, identityFn));
      return success();
    }
    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto equivIntType = rewriter.getIntegerType(elemBitwidth);
    auto packFn = [loc, &rewriter, equivIntType,
                   shuffleIntType](Value unpackedVal) -> Value {
      auto asInt =
          arith::BitcastOp::create(rewriter, loc, equivIntType, unpackedVal);
      return arith::ExtUIOp::create(rewriter, loc, shuffleIntType, asInt);
    };
    auto unpackFn = [loc, &rewriter, equivIntType,
                     valueTy](Value packedVal) -> Value {
      auto asInt =
          arith::TruncIOp::create(rewriter, loc, equivIntType, packedVal);
      return arith::BitcastOp::create(rewriter, loc, valueTy, asInt);
    };
    rewriter.replaceOp(
        op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
                                           op.getOp(), *ci, packFn, unpackFn));
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};
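/// Lowers vector `gpu.subgroup_reduce` ops to a series of `gpu.shuffle` ops,
/// bitcasting the whole vector to a single shuffle-compatible integer when it
/// fits within `shuffleBitwidth`.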
struct VectorSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy)
      return rewriter.notifyMatchFailure(op, "value type is not a vector");
    unsigned vecBitwidth =
        vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
    if (vecBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op,
          llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
                        "to shuffles of size {1}",
                        vecBitwidth, shuffleBitwidth));
    unsigned elementsPerShuffle =
        shuffleBitwidth / vecTy.getElementTypeBitWidth();
    if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "shuffle bitwidth is not a multiple of the element bitwidth");
    Location loc = op.getLoc();

    // If the reduced type is smaller than the native shuffle size, extend it,
    // perform the shuffles, and extract at the end.
    auto extendedVecTy = VectorType::get(
        static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
    Value extendedInput = op.getValue();
    if (vecBitwidth < shuffleBitwidth) {
      auto zero = arith::ConstantOp::create(
          rewriter, loc, rewriter.getZeroAttr(extendedVecTy));
      extendedInput = vector::InsertStridedSliceOp::create(
          rewriter, loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
    }
    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto shuffleVecType = VectorType::get(1, shuffleIntType);

    auto packFn = [loc, &rewriter,
                   shuffleVecType](Value unpackedVal) -> Value {
      auto asIntVec = vector::BitCastOp::create(rewriter, loc, shuffleVecType,
                                                unpackedVal);
      return vector::ExtractOp::create(rewriter, loc, asIntVec, 0);
    };
    auto unpackFn = [loc, &rewriter, shuffleVecType,
                     extendedVecTy](Value packedVal) -> Value {
      auto asIntVec = vector::BroadcastOp::create(rewriter, loc, shuffleVecType,
                                                  packedVal);
      return vector::BitCastOp::create(rewriter, loc, extendedVecTy, asIntVec);
    };
    Value res = createSubgroupShuffleReduction(
        rewriter, loc, extendedInput, op.getOp(), *ci, packFn, unpackFn);

    if (vecBitwidth < shuffleBitwidth) {
      res = vector::ExtractStridedSliceOp::create(
          rewriter, loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
          /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};
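/// Emits a subgroup reduction over a cluster of contiguous lanes using AMD
/// data-parallel primitives (DPP). Quad permutes and row mirrors cover
/// clusters of up to 16 lanes; 32- and 64-lane clusters additionally need
/// chipset-specific row broadcasts, permlanes, or readlanes.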
static FailureOr<Value>
createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
                           Value input, gpu::AllReduceOperation mode,
                           const ClusterInfo &ci, amdgpu::Chipset chipset) {
  Location loc = op.getLoc();
  Value dpp;
  Value res = input;
  constexpr int allRows = 0xf;
  constexpr int allBanks = 0xf;
  const bool boundCtrl = true;
  if (ci.clusterSize >= 2) {
    // Perform reduction between all lanes N <-> N+1.
    dpp = amdgpu::DPPOp::create(
        rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
        rewriter.getI32ArrayAttr({1, 0, 3, 2}), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(
        rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 4) {
    // Perform reduction between all lanes N <-> N+2.
    dpp = amdgpu::DPPOp::create(
        rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
        rewriter.getI32ArrayAttr({2, 3, 0, 1}), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(
        rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 8) {
    // Perform reduction between all lanes N <-> 7-N,
    // e.g., lane[0] <-> lane[7], lane[1] <-> lane[6]..., lane[3] <-> lane[4].
    dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res,
                                amdgpu::DPPPerm::row_half_mirror,
                                rewriter.getUnitAttr(), allRows, allBanks,
                                boundCtrl);
    res = vector::makeArithReduction(
        rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 16) {
    // Perform reduction between all lanes N <-> 15-N,
    // e.g., lane[0] <-> lane[15], lane[1] <-> lane[14]..., lane[7] <-> lane[8].
    dpp = amdgpu::DPPOp::create(
        rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::row_mirror,
        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(
        rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 32) {
    if (chipset.majorVersion <= 9) {
      // Broadcast the last lane of each row to the next row; the row mask
      // restricts writes to rows 1 and 3 so the other rows stay intact.
      dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res,
                                  amdgpu::DPPPerm::row_bcast_15,
                                  rewriter.getUnitAttr(), /*rowMask=*/0xa,
                                  allBanks, /*boundCtrl=*/false);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
    } else if (chipset.majorVersion <= 12) {
      // Use permlanex16 to exchange values between pairs of 16-lane rows.
      Value uint32Max = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(-1));
      dpp = ROCDL::PermlaneX16Op::create(rewriter, loc, res.getType(), res, res,
                                         uint32Max, uint32Max,
                                         /*fi=*/true, /*boundControl=*/false);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
    } else {
      return rewriter.notifyMatchFailure(
          op, "Subgroup reduce lowering to DPP not currently supported for "
              "this device.");
    }
    if (ci.subgroupSize == 32) {
      Value lane31 = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31));
      res =
          ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane31);
    }
  }
  if (ci.clusterSize >= 64) {
    if (chipset.majorVersion <= 9) {
      // Broadcast lane 31 to rows 2 and 3; the row mask restricts writes to
      // those rows.
      dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res,
                                  amdgpu::DPPPerm::row_bcast_31,
                                  rewriter.getUnitAttr(), /*rowMask=*/0xc,
                                  allBanks, /*boundCtrl=*/false);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), dpp, res);

      // Obtain the reduction from the last lane; the earlier lanes only hold
      // partial values.
      Value lane63 = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
      res =
          ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane63);
    } else if (chipset.majorVersion <= 12) {
      // Lanes 0-31 and 32-63 hold partial reductions; combine the values from
      // lanes 31 and 63 to finish the 64-lane reduction.
      Value lane31 = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31));
      Value lane63 = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
      lane31 =
          ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane31);
      lane63 =
          ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane63);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), lane31, lane63);
    } else {
      return rewriter.notifyMatchFailure(
          op, "Subgroup reduce lowering to DPP not currently supported for "
              "this device.");
    }
  }

  assert(res.getType() == input.getType());
  return res;
}
struct ScalarSubgroupReduceToDPP final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToDPP(MLIRContext *ctx, unsigned subgroupSize,
                            bool matchClustered, amdgpu::Chipset chipset,
                            PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        matchClustered(matchClustered), chipset(chipset) {}
  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();
511 if (ci->clusterStride != 1)
513 op,
"Subgroup reductions using DPP are currently only available for "
514 "clusters of contiguous lanes.");
    Type valueTy = op.getType();
    if (!valueTy.isIntOrFloat())
      return rewriter.notifyMatchFailure(
          op, "Value type is not a compatible scalar.");
    FailureOr<Value> dpp = createSubgroupDPPReduction(
        rewriter, op, op.getValue(), op.getOp(), *ci, chipset);
    if (failed(dpp))
      return failure();

    rewriter.replaceOp(op, dpp.value());
    return success();
  }

private:
  unsigned subgroupSize = 0;
  bool matchClustered = false;
  amdgpu::Chipset chipset;
};
} // namespace
void mlir::populateGpuBreakDownSubgroupReducePatterns(
    RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
    PatternBenefit benefit) {
  patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
                                        maxShuffleBitwidth, benefit);
  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}

void mlir::populateGpuLowerSubgroupReduceToDPPPatterns(
    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
    PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
                                          /*matchClustered=*/false, chipset,
                                          benefit);
}

void mlir::populateGpuLowerClusteredSubgroupReduceToDPPPatterns(
    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
    PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
                                          /*matchClustered=*/true, chipset,
                                          benefit);
}
void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/false, benefit);
}
void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/true, benefit);
}
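// Usage sketch (illustrative, not part of this file): a conversion pass would
// typically collect these patterns and apply them greedily. The surrounding
// plumbing below is hypothetical, and the greedy driver entry point may be
// named differently depending on the MLIR version, e.g.
// applyPatternsAndFoldGreedily in older releases:
//
//   RewritePatternSet patterns(op->getContext());
//   populateGpuBreakDownSubgroupReducePatterns(patterns,
//                                              /*maxShuffleBitwidth=*/32);
//   populateGpuLowerSubgroupReduceToShufflePatterns(patterns,
//                                                   /*subgroupSize=*/32,
//                                                   /*shuffleBitwidth=*/32);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     signalPassFailure();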