23 #include "llvm/Support/FormatVariadic.h"
24 #include "llvm/Support/MathExtras.h"
46 struct BreakDownSubgroupReduce final :
OpRewritePattern<gpu::SubgroupReduceOp> {
47 BreakDownSubgroupReduce(
MLIRContext *ctx,
unsigned maxShuffleBitwidth,
54 auto vecTy = dyn_cast<VectorType>(op.getType());
55 if (!vecTy || vecTy.getNumElements() < 2)
58 assert(vecTy.getRank() == 1 &&
"Unexpected vector type");
59 assert(!vecTy.isScalable() &&
"Unexpected vector type");
61 Type elemTy = vecTy.getElementType();
63 if (elemBitwidth >= maxShuffleBitwidth)
65 op, llvm::formatv(
"element type too large ({0}), cannot break down "
66 "into vectors of bitwidth {1} or less",
67 elemBitwidth, maxShuffleBitwidth));
69 unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
70 assert(elementsPerShuffle >= 1);
72 unsigned numNewReductions =
74 assert(numNewReductions >= 1);
75 if (numNewReductions == 1)
82 for (
unsigned i = 0; i != numNewReductions; ++i) {
83 int64_t startIdx = i * elementsPerShuffle;
85 std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
86 int64_t numElems = endIdx - startIdx;
91 rewriter.
create<vector::ExtractOp>(loc, op.getValue(), startIdx);
93 extracted = rewriter.
create<vector::ExtractStridedSliceOp>(
94 loc, op.getValue(), startIdx, numElems,
99 loc, extracted, op.getOp(), op.getUniform());
101 res = rewriter.
create<vector::InsertOp>(loc,
reduce, res, startIdx);
105 res = rewriter.
create<vector::InsertStridedSliceOp>(
106 loc,
reduce, res, startIdx, 1);
114 unsigned maxShuffleBitwidth = 0;
125 struct ScalarizeSingleElementReduce final
131 auto vecTy = dyn_cast<VectorType>(op.getType());
132 if (!vecTy || vecTy.getNumElements() != 1)
135 assert(vecTy.getRank() == 1 &&
"Unexpected vector type");
136 assert(!vecTy.isScalable() &&
"Unexpected vector type");
138 Value extracted = rewriter.
create<vector::ExtractOp>(loc, op.getValue(), 0);
140 loc, extracted, op.getOp(), op.getUniform());
152 static Value createSubgroupShuffleReduction(
156 assert(llvm::isPowerOf2_32(subgroupSize));
159 Value laneVal = input;
161 for (
unsigned i = 1; i < subgroupSize; i <<= 1) {
162 Value shuffled = builder
163 .
create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
165 gpu::ShuffleMode::XOR)
169 laneVal, unpackFn(shuffled));
177 struct ScalarSubgroupReduceToShuffles final
179 ScalarSubgroupReduceToShuffles(
MLIRContext *ctx,
unsigned subgroupSize,
180 unsigned shuffleBitwidth,
183 shuffleBitwidth(shuffleBitwidth) {}
187 Type valueTy = op.getType();
188 unsigned elemBitwidth =
190 if (!valueTy.
isIntOrFloat() || elemBitwidth > shuffleBitwidth)
192 op,
"value type is not a compatible scalar");
196 if (elemBitwidth == shuffleBitwidth) {
197 auto identityFn = [](
Value v) {
return v; };
198 rewriter.
replaceOp(op, createSubgroupShuffleReduction(
199 rewriter, loc, op.getValue(), op.getOp(),
200 subgroupSize, identityFn, identityFn));
206 auto packFn = [loc, &rewriter, equivIntType,
209 rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
210 return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
212 auto unpackFn = [loc, &rewriter, equivIntType,
215 rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
216 return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
219 rewriter.replaceOp(op, createSubgroupShuffleReduction(
220 rewriter, loc, op.getValue(), op.getOp(),
221 subgroupSize, packFn, unpackFn));
226 unsigned subgroupSize = 0;
227 unsigned shuffleBitwidth = 0;
231 struct VectorSubgroupReduceToShuffles final
233 VectorSubgroupReduceToShuffles(
MLIRContext *ctx,
unsigned subgroupSize,
234 unsigned shuffleBitwidth,
237 shuffleBitwidth(shuffleBitwidth) {}
241 auto vecTy = dyn_cast<VectorType>(op.getType());
245 unsigned vecBitwidth =
246 vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
247 if (vecBitwidth > shuffleBitwidth)
250 llvm::formatv(
"vector type bitwidth too large ({0}), cannot lower "
251 "to shuffles of size {1}",
252 vecBitwidth, shuffleBitwidth));
254 unsigned elementsPerShuffle =
255 shuffleBitwidth / vecTy.getElementTypeBitWidth();
256 if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
258 op,
"shuffle bitwidth is not a multiple of the element bitwidth");
265 static_cast<int64_t
>(elementsPerShuffle), vecTy.getElementType());
266 Value extendedInput = op.getValue();
267 if (vecBitwidth < shuffleBitwidth) {
268 auto zero = rewriter.
create<arith::ConstantOp>(
270 extendedInput = rewriter.
create<vector::InsertStridedSliceOp>(
271 loc, extendedInput, zero, 0, 1);
277 auto packFn = [loc, &rewriter, shuffleVecType](
Value unpackedVal) ->
Value {
279 rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
280 return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
282 auto unpackFn = [loc, &rewriter, shuffleVecType,
285 rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
286 return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
290 createSubgroupShuffleReduction(rewriter, loc, extendedInput, op.getOp(),
291 subgroupSize, packFn, unpackFn);
293 if (vecBitwidth < shuffleBitwidth) {
294 res = rewriter.create<vector::ExtractStridedSliceOp>(
295 loc, res, 0, vecTy.getNumElements(),
299 rewriter.replaceOp(op, res);
304 unsigned subgroupSize = 0;
305 unsigned shuffleBitwidth = 0;
313 maxShuffleBitwidth, benefit);
314 patterns.
add<ScalarizeSingleElementReduce>(patterns.
getContext(), benefit);
320 patterns.
add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
321 patterns.
getContext(), subgroupSize, shuffleBitwidth, benefit);
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
vector::CombiningKind convertReductionKind(gpu::AllReduceOperation mode)
Returns the matching vector combining kind.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Include the generated interface declarations.
void populateGpuLowerSubgroupReduceToShufflePattenrs(RewritePatternSet &patterns, unsigned subgroupSize, unsigned shuffleBitwidth=32, PatternBenefit benefit=1)
Collect a set of patterns to lower gpu.subgroup_reduce into gpu.shuffle ops over shuffleBitwidth scal...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateGpuBreakDownSubgrupReducePatterns(RewritePatternSet &patterns, unsigned maxShuffleBitwidth=32, PatternBenefit benefit=1)
Collect a set of patterns to break down subgroup_reduce ops into smaller ones supported by the target...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...