#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
                          PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
  }
  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() < 2)
      return rewriter.notifyMatchFailure(op, "not a multi-element reduction");
    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");
    Type elemTy = vecTy.getElementType();
    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
    if (elemBitwidth >= maxShuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("element type too large ({0}), cannot break down "
                            "into vectors of bitwidth {1} or less",
                            elemBitwidth, maxShuffleBitwidth));
    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
    assert(elementsPerShuffle >= 1);

    unsigned numNewReductions =
        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
    assert(numNewReductions >= 1);
    if (numNewReductions == 1)
      return rewriter.notifyMatchFailure(op, "nothing to break down");
    Location loc = op.getLoc();
    Value res =
        arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(vecTy));
    for (unsigned i = 0; i != numNewReductions; ++i) {
      int64_t startIdx = i * elementsPerShuffle;
      int64_t endIdx =
          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
      int64_t numElems = endIdx - startIdx;
      Value extracted;
      if (numElems == 1) {
        extracted =
            vector::ExtractOp::create(rewriter, loc, op.getValue(), startIdx);
      } else {
        extracted = vector::ExtractStridedSliceOp::create(
            rewriter, loc, op.getValue(), /*offsets=*/startIdx,
            /*sizes=*/numElems, /*strides=*/1);
      }
      Value reduce = gpu::SubgroupReduceOp::create(
          rewriter, loc, extracted, op.getOp(), op.getUniform(),
          op.getClusterSize(), op.getClusterStride());
      if (numElems == 1) {
        res = vector::InsertOp::create(rewriter, loc, reduce, res, startIdx);
        continue;
      }
      res = vector::InsertStridedSliceOp::create(
          rewriter, loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
private:
  unsigned maxShuffleBitwidth = 0;
};
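// Illustrative sketch of the rewrite performed by `BreakDownSubgroupReduce`,
// assuming `maxShuffleBitwidth` is 32. The IR below is hand-written for
// illustration and not taken verbatim from a test:
//
//   %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
//
// becomes, roughly:
//
//   %cst = arith.constant dense<0.0> : vector<3xf16>
//   %e0  = vector.extract_strided_slice %x {offsets = [0], sizes = [2],
//            strides = [1]} : vector<3xf16> to vector<2xf16>
//   %r0  = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
//   %v0  = vector.insert_strided_slice %r0, %cst {offsets = [0], strides = [1]}
//            : vector<2xf16> into vector<3xf16>
//   %e1  = vector.extract %x[2] : f16 from vector<3xf16>
//   %r1  = gpu.subgroup_reduce add %e1 : (f16) -> f16
//   %a   = vector.insert %r1, %v0[2] : f16 into vector<3xf16>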
struct ScalarizeSingleElementReduce final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() != 1)
      return rewriter.notifyMatchFailure(op, "not a single-element reduction");
    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");
    Location loc = op.getLoc();
    Value extracted =
        vector::ExtractOp::create(rewriter, loc, op.getValue(), 0);
    Value reduce = gpu::SubgroupReduceOp::create(
        rewriter, loc, extracted, op.getOp(), op.getUniform(),
        op.getClusterSize(), op.getClusterStride());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
    return success();
  }
};
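// Illustrative sketch of the `ScalarizeSingleElementReduce` rewrite
// (hand-written IR, for illustration only):
//
//   %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32>
//
// becomes, roughly:
//
//   %e = vector.extract %x[0] : f32 from vector<1xf32>
//   %r = gpu.subgroup_reduce add %e : (f32) -> f32
//   %a = vector.broadcast %r : f32 to vector<1xf32>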
/// Describes the cluster of lanes reduced together within a subgroup.
struct ClusterInfo {
  unsigned clusterStride;
  unsigned clusterSize;
  unsigned subgroupSize;
};
static FailureOr<ClusterInfo>
getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
  assert(llvm::isPowerOf2_32(subgroupSize));
  std::optional<uint32_t> clusterSize = op.getClusterSize();
  assert(!clusterSize || llvm::isPowerOf2_32(*clusterSize));
  if (clusterSize && *clusterSize > subgroupSize)
    return op.emitOpError() << "cluster size " << *clusterSize
                            << " is greater than subgroup size "
                            << subgroupSize;
  unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
  auto clusterStride = op.getClusterStride();
  assert(llvm::isPowerOf2_32(clusterStride));
  if (clusterStride >= subgroupSize)
    return op.emitOpError() << "cluster stride " << clusterStride
                            << " is not less than subgroup size "
                            << subgroupSize;
  return ClusterInfo{clusterStride, effectiveClusterSize, subgroupSize};
}
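// For intuition (numbers chosen purely for illustration): with a subgroup size
// of 32, `cluster_size = 4`, and `cluster_stride = 1`, lanes {0,1,2,3},
// {4,5,6,7}, ... each form a cluster and are reduced independently. With
// `cluster_stride = 8` instead, the clusters are the strided lane sets
// {0,8,16,24}, {1,9,17,25}, and so on.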
/// Emits a subgroup reduction using a sequence of shuffles. `packFn` converts
/// the value to the native shuffle type and `unpackFn` converts it back to the
/// reduction type.
static Value
createSubgroupShuffleReduction(OpBuilder &builder, Location loc, Value input,
                               gpu::AllReduceOperation mode,
                               const ClusterInfo &ci,
                               function_ref<Value(Value)> packFn,
                               function_ref<Value(Value)> unpackFn) {
  // The lane value always stays in the original type; arith reductions are
  // performed on it directly.
  Value laneVal = input;
  // Parallel reduction using butterfly (XOR) shuffles.
  for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize;
       i <<= 1) {
    Value shuffled = gpu::ShuffleOp::create(builder, loc, packFn(laneVal), i,
                                            /*width=*/ci.subgroupSize,
                                            /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = vector::makeArithReduction(builder, loc,
                                         gpu::convertReductionKind(mode),
                                         laneVal, unpackFn(shuffled));
  }

  return laneVal;
}
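// For intuition: with `clusterStride = 1` and `clusterSize = 8`, the loop
// above issues XOR shuffles with offsets 1, 2, and 4, so after log2(8) = 3
// steps every lane of a cluster holds the reduction of all 8 lanes. With
// `clusterStride = 4` the offsets become 4, 8, and 16, combining lanes that
// are 4 apart. (Offsets listed here for illustration only.)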
/// Lowers scalar gpu subgroup reductions to a series of shuffles.
struct ScalarSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }
    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();
    Type valueTy = op.getType();
    unsigned elemBitwidth =
        getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
    if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "value type is not a compatible scalar");
    Location loc = op.getLoc();
    // A value as wide as the native shuffle type needs no packing.
    if (elemBitwidth == shuffleBitwidth) {
      auto identityFn = [](Value v) { return v; };
      rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                 rewriter, loc, op.getValue(), op.getOp(), *ci,
                                 identityFn, identityFn));
      return success();
    }
    // Narrower values are zero-extended to the shuffle bitwidth before each
    // shuffle and truncated back afterwards.
    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto equivIntType = rewriter.getIntegerType(elemBitwidth);
    auto packFn = [loc, &rewriter, equivIntType,
                   shuffleIntType](Value unpackedVal) -> Value {
      auto asInt =
          arith::BitcastOp::create(rewriter, loc, equivIntType, unpackedVal);
      return arith::ExtUIOp::create(rewriter, loc, shuffleIntType, asInt);
    };
    auto unpackFn = [loc, &rewriter, equivIntType,
                     valueTy](Value packedVal) -> Value {
      auto asInt =
          arith::TruncIOp::create(rewriter, loc, equivIntType, packedVal);
      return arith::BitcastOp::create(rewriter, loc, valueTy, asInt);
    };
    rewriter.replaceOp(
        op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
                                           op.getOp(), *ci, packFn, unpackFn));
    return success();
  }
private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};
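// Illustrative sketch of the packing done by `ScalarSubgroupReduceToShuffles`
// for a narrow scalar, assuming `shuffleBitwidth` is 32 (hand-written IR, for
// illustration only). An f16 value is shuffled as an i32 and reduced in f16:
//
//   %pack        = arith.bitcast %x : f16 to i16
//   %wide        = arith.extui %pack : i16 to i32
//   %shfl, %ok   = gpu.shuffle xor %wide, %offset, %width : i32
//   %narrow      = arith.trunci %shfl : i32 to i16
//   %y           = arith.bitcast %narrow : i16 to f16
//   // %y is then combined with the running lane value via the arith
//   // reduction corresponding to the reduce mode.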
/// Lowers vector gpu subgroup reductions to a series of shuffles.
struct VectorSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }
    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy)
      return rewriter.notifyMatchFailure(op, "value type is not a vector");
    unsigned vecBitwidth =
        vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
    if (vecBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
                            "to shuffles of size {1}",
                            vecBitwidth, shuffleBitwidth));
    unsigned elementsPerShuffle =
        shuffleBitwidth / vecTy.getElementTypeBitWidth();
    if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "shuffle bitwidth is not a multiple of the element bitwidth");
    Location loc = op.getLoc();

    // If the reduced type is narrower than the native shuffle size, extend it,
    // perform the shuffles, and extract the original elements at the end.
    auto extendedVecTy = VectorType::get(
        static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
    Value extendedInput = op.getValue();
    if (vecBitwidth < shuffleBitwidth) {
      auto zero = arith::ConstantOp::create(
          rewriter, loc, rewriter.getZeroAttr(extendedVecTy));
      extendedInput = vector::InsertStridedSliceOp::create(
          rewriter, loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
    }
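    // For illustration: with a `shuffleBitwidth` of 32, a `vector<1xf16>`
    // input (16 bits) is widened here to a `vector<2xf16>` by inserting it
    // into a zero vector, so that it can later be bitcast to a single i32 for
    // the shuffle. (Example values chosen for illustration only.)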
    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto shuffleVecType = VectorType::get(1, shuffleIntType);

    auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
      auto asIntVec = vector::BitCastOp::create(rewriter, loc, shuffleVecType,
                                                unpackedVal);
      return vector::ExtractOp::create(rewriter, loc, asIntVec, 0);
    };
    auto unpackFn = [loc, &rewriter, shuffleVecType,
                     extendedVecTy](Value packedVal) -> Value {
      auto asIntVec = vector::BroadcastOp::create(rewriter, loc, shuffleVecType,
                                                  packedVal);
      return vector::BitCastOp::create(rewriter, loc, extendedVecTy, asIntVec);
    };
    Value res = createSubgroupShuffleReduction(
        rewriter, loc, extendedInput, op.getOp(), *ci, packFn, unpackFn);

    if (vecBitwidth < shuffleBitwidth) {
      res = vector::ExtractStridedSliceOp::create(
          rewriter, loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
          /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};
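// Illustrative sketch of the vector-to-shuffles packing, assuming
// `shuffleBitwidth` is 32 (hand-written IR for illustration). A `vector<2xf16>`
// value is bitcast to a single-element i32 vector for each shuffle and cast
// back before the arith reduction:
//
//   %iv        = vector.bitcast %v : vector<2xf16> to vector<1xi32>
//   %i         = vector.extract %iv[0] : i32 from vector<1xi32>
//   %s, %ok    = gpu.shuffle xor %i, %offset, %width : i32
//   %sv        = vector.broadcast %s : i32 to vector<1xi32>
//   %r         = vector.bitcast %sv : vector<1xi32> to vector<2xf16>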
static FailureOr<Value>
createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
                           Value input, gpu::AllReduceOperation mode,
                           const ClusterInfo &ci, amdgpu::Chipset chipset) {
  Location loc = op.getLoc();
  Value dpp;
  Value res = input;
  constexpr int allRows = 0xf;
  constexpr int allBanks = 0xf;
  const bool boundCtrl = true;
  if (ci.clusterSize >= 2) {
    // Combine lanes N and N + 1.
    dpp = amdgpu::DPPOp::create(
        rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
        rewriter.getI32ArrayAttr({1, 0, 3, 2}), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 4) {
    // Combine lanes N and N + 2.
    dpp = amdgpu::DPPOp::create(
        rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
        rewriter.getI32ArrayAttr({2, 3, 0, 1}), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 8) {
    // Mirror lanes within each half-row: lane N combines with lane 7 - N.
    dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res,
                                amdgpu::DPPPerm::row_half_mirror,
                                rewriter.getUnitAttr(), allRows, allBanks,
                                boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 16) {
    // Mirror lanes within each row: lane N combines with lane 15 - N.
    dpp = amdgpu::DPPOp::create(
        rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::row_mirror,
        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 32) {
    if (chipset.majorVersion <= 9) {
      // Broadcast the lane-15 value of each row into the next row and combine;
      // the row mask keeps the update out of rows 0 and 2.
      dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res,
                                  amdgpu::DPPPerm::row_bcast_15,
                                  rewriter.getUnitAttr(), 0xa, allBanks,
                                  /*boundCtrl=*/false);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
    } else if (chipset.majorVersion <= 12) {
      // GFX10+ has no row_bcast DPP modes; cross the 16-lane rows with a
      // bitmode swizzle or a permlanex16.
      if (ci.subgroupSize == 64 && ci.clusterSize == 32) {
        dpp = amdgpu::SwizzleBitModeOp::create(rewriter, loc, res, 0,
                                               /*orMask=*/0, /*xorMask=*/0x1f);
      } else {
        Value uint32Max = arith::ConstantOp::create(
            rewriter, loc, rewriter.getI32Type(),
            rewriter.getI32IntegerAttr(-1));
        dpp = ROCDL::PermlaneX16Op::create(rewriter, loc, res.getType(), res,
                                           res, uint32Max, uint32Max,
                                           /*fi=*/true, /*boundControl=*/false);
      }
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
    } else {
      return rewriter.notifyMatchFailure(
          op, "Subgroup reduce lowering to DPP not currently supported for "
              "this device.");
    }
    if (ci.subgroupSize == 32) {
      Value lane31 = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31));
      res =
          ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane31);
    }
  }
  if (ci.clusterSize >= 64) {
    if (chipset.majorVersion <= 9) {
      // Broadcast the lane-31 value into rows 2 and 3 and combine.
      dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res,
                                  amdgpu::DPPPerm::row_bcast_31,
                                  rewriter.getUnitAttr(), 0xc, allBanks,
                                  /*boundCtrl=*/false);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
      // Only the upper rows hold the full reduction; broadcast it from the
      // last lane.
      Value lane63 = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
      res =
          ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane63);
    } else if (chipset.majorVersion <= 12) {
      // Each 32-lane half has already been reduced; combine the two halves by
      // reading the values held in lane 31 and lane 63.
      Value lane31 = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31));
      Value lane63 = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
      Value lane31Val =
          ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane31);
      Value lane63Val =
          ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane63);
      res = vector::makeArithReduction(rewriter, loc,
                                       gpu::convertReductionKind(mode),
                                       lane31Val, lane63Val);
    } else {
      return rewriter.notifyMatchFailure(
          op, "Subgroup reduce lowering to DPP not currently supported for "
              "this device.");
    }
  }

  return res;
}
/// Lowers scalar gpu subgroup reductions to a series of amdgpu.dpp ops.
struct ScalarSubgroupReduceToDPP final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToDPP(MLIRContext *ctx, unsigned subgroupSize,
                            bool matchClustered, amdgpu::Chipset chipset,
                            PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        matchClustered(matchClustered), chipset(chipset) {}
  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }
    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    if (ci->clusterStride != 1)
      return rewriter.notifyMatchFailure(
          op, "Subgroup reductions using DPP are currently only available for "
              "clusters of contiguous lanes.");
    Type valueTy = op.getType();
    if (!valueTy.isIntOrFloat())
      return rewriter.notifyMatchFailure(
          op, "Value type is not a compatible scalar.");
    FailureOr<Value> dpp = createSubgroupDPPReduction(
        rewriter, op, op.getValue(), op.getOp(), *ci, chipset);
    if (failed(dpp))
      return failure();

    rewriter.replaceOp(op, *dpp);
    return success();
  }
private:
  unsigned subgroupSize = 0;
  bool matchClustered = false;
  amdgpu::Chipset chipset;
};
void mlir::populateGpuBreakDownSubgroupReducePatterns(
    RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
    PatternBenefit benefit) {
  patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
                                        maxShuffleBitwidth, benefit);
  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}
void mlir::populateGpuLowerSubgroupReduceToDPPPatterns(
    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
    PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
                                          /*matchClustered=*/false, chipset,
                                          benefit);
}
void mlir::populateGpuLowerClusteredSubgroupReduceToDPPPatterns(
    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
    PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
                                          /*matchClustered=*/true, chipset,
                                          benefit);
}
void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/false, benefit);
}
void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/true, benefit);
}
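// Minimal usage sketch of the populate functions above (the surrounding pass
// boilerplate is illustrative and not part of this file; it assumes a
// rewrite-pattern driven pass over a GPU module or function):
//
//   void runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     // First break large vector reductions into pieces that fit in a single
//     // shuffle, then lower the remaining reductions to gpu.shuffle ops.
//     populateGpuBreakDownSubgroupReducePatterns(patterns,
//                                                /*maxShuffleBitwidth=*/32);
//     populateGpuLowerSubgroupReduceToShufflePatterns(patterns,
//                                                     /*subgroupSize=*/32);
//     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//       signalPassFailure();
//   }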