22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/MathExtras.h"
// Pattern: break a multi-element vector `gpu.subgroup_reduce` into a sequence
// of smaller reductions so that each piece fits into at most
// `maxShuffleBitwidth` bits.
// NOTE(review): this excerpt is elided (the embedded original line numbers
// jump), so some statements are missing; comments describe visible code only.
45 struct BreakDownSubgroupReduce final :
OpRewritePattern<gpu::SubgroupReduceOp> {
46 BreakDownSubgroupReduce(
MLIRContext *ctx,
unsigned maxShuffleBitwidth,
// Rewrites one vector subgroup_reduce; non-matching cases presumably exit
// via notifyMatchFailure (elided) — TODO confirm against full source.
51 LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
// Only vector results with at least 2 elements are broken down.
53 auto vecTy = dyn_cast<VectorType>(op.getType());
54 if (!vecTy || vecTy.getNumElements() < 2)
// Only rank-1, fixed-length (non-scalable) vectors are expected here.
57 assert(vecTy.getRank() == 1 &&
"Unexpected vector type");
58 assert(!vecTy.isScalable() &&
"Unexpected vector type");
60 Type elemTy = vecTy.getElementType();
// A single element must fit in the shuffle bitwidth; otherwise no smaller
// vector can be formed and we report a match failure with a diagnostic.
62 if (elemBitwidth >= maxShuffleBitwidth)
64 op, llvm::formatv(
"element type too large ({0}), cannot break down "
65 "into vectors of bitwidth {1} or less",
66 elemBitwidth, maxShuffleBitwidth));
// How many elements one shuffle can carry; the computation of
// `numNewReductions` (likely a ceiling division) is elided.
68 unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
69 assert(elementsPerShuffle >= 1);
71 unsigned numNewReductions =
73 assert(numNewReductions >= 1);
// Nothing to split when everything already fits into one shuffle.
74 if (numNewReductions == 1)
// Emit one reduction per chunk [startIdx, endIdx) of the input vector.
81 for (
unsigned i = 0; i != numNewReductions; ++i) {
82 int64_t startIdx = i * elementsPerShuffle;
84 std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
85 int64_t numElems = endIdx - startIdx;
// Single leftover element: extract a scalar ...
90 rewriter.
create<vector::ExtractOp>(loc, op.getValue(), startIdx);
// ... otherwise extract a strided sub-vector of `numElems` elements.
92 extracted = rewriter.
create<vector::ExtractStridedSliceOp>(
93 loc, op.getValue(), startIdx, numElems,
// Reduce the extracted piece, preserving the original combining op and
// `uniform` flag (creation of the new subgroup_reduce is elided above).
98 loc, extracted, op.getOp(), op.getUniform());
// Re-assemble the result vector: scalar insert for one element, strided
// slice insert for larger chunks.
100 res = rewriter.
create<vector::InsertOp>(loc,
reduce, res, startIdx);
104 res = rewriter.
create<vector::InsertStridedSliceOp>(
105 loc,
reduce, res, startIdx, 1);
// Maximum bitwidth a single shuffle can carry; set by the constructor.
113 unsigned maxShuffleBitwidth = 0;
// Pattern: turn a single-element vector `gpu.subgroup_reduce` into a scalar
// reduce by extracting element 0. The creation of the replacement scalar
// subgroup_reduce op is elided from this excerpt.
124 struct ScalarizeSingleElementReduce final
128 LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
// Applies only to vector results with exactly one element.
130 auto vecTy = dyn_cast<VectorType>(op.getType());
131 if (!vecTy || vecTy.getNumElements() != 1)
// Only rank-1, fixed-length (non-scalable) vectors are expected here.
134 assert(vecTy.getRank() == 1 &&
"Unexpected vector type");
135 assert(!vecTy.isScalable() &&
"Unexpected vector type");
// Extract the lone element and reduce it as a scalar, preserving the
// original combining op and `uniform` flag.
137 Value extracted = rewriter.
create<vector::ExtractOp>(loc, op.getValue(), 0);
139 loc, extracted, op.getOp(), op.getUniform());
// Builds an XOR-butterfly shuffle reduction across the subgroup: after
// log2(subgroupSize) gpu.shuffle steps every lane holds the full reduction.
// `packFn`/`unpackFn` convert the value to/from a type the shuffle can carry.
// The per-step accumulation and the final return are elided in this excerpt.
151 static Value createSubgroupShuffleReduction(
// The XOR butterfly only covers all lanes for power-of-two subgroup sizes.
155 assert(llvm::isPowerOf2_32(subgroupSize));
158 Value laneVal = input;
// Shuffle offsets 1, 2, 4, ..., subgroupSize/2.
160 for (
unsigned i = 1; i < subgroupSize; i <<= 1) {
161 Value shuffled = builder
162 .
create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
164 gpu::ShuffleMode::XOR)
// Combine this lane's value with the shuffled partner's value.
168 laneVal, unpackFn(shuffled));
// Pattern: lower a scalar `gpu.subgroup_reduce` to a gpu.shuffle-based
// butterfly reduction. Values narrower than the shuffle bitwidth are
// round-tripped through an integer of the shuffle width.
// NOTE(review): interior lines are elided in this excerpt.
176 struct ScalarSubgroupReduceToShuffles final
178 ScalarSubgroupReduceToShuffles(
MLIRContext *ctx,
unsigned subgroupSize,
179 unsigned shuffleBitwidth,
182 shuffleBitwidth(shuffleBitwidth) {}
184 LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
186 Type valueTy = op.getType();
187 unsigned elemBitwidth =
// Only int/float scalars that fit in one shuffle are supported.
189 if (!valueTy.
isIntOrFloat() || elemBitwidth > shuffleBitwidth)
191 op,
"value type is not a compatible scalar");
// Exact-width case: the value can ride the shuffle unchanged, so both
// pack and unpack are the identity.
195 if (elemBitwidth == shuffleBitwidth) {
196 auto identityFn = [](
Value v) {
return v; };
197 rewriter.
replaceOp(op, createSubgroupShuffleReduction(
198 rewriter, loc, op.getValue(), op.getOp(),
199 subgroupSize, identityFn, identityFn));
// Narrow case: bitcast to an equivalent integer, zero-extend up to the
// shuffle width before shuffling ...
205 auto packFn = [loc, &rewriter, equivIntType,
208 rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
209 return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
// ... and truncate + bitcast back to the original type afterwards.
211 auto unpackFn = [loc, &rewriter, equivIntType,
214 rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
215 return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
218 rewriter.replaceOp(op, createSubgroupShuffleReduction(
219 rewriter, loc, op.getValue(), op.getOp(),
220 subgroupSize, packFn, unpackFn));
// Configuration captured by the constructor.
225 unsigned subgroupSize = 0;
226 unsigned shuffleBitwidth = 0;
// Pattern: lower a small-vector `gpu.subgroup_reduce` (total bitwidth at most
// `shuffleBitwidth`) to a gpu.shuffle-based butterfly reduction, padding the
// vector up to the shuffle width when necessary.
// NOTE(review): interior lines are elided in this excerpt.
230 struct VectorSubgroupReduceToShuffles final
232 VectorSubgroupReduceToShuffles(
MLIRContext *ctx,
unsigned subgroupSize,
233 unsigned shuffleBitwidth,
236 shuffleBitwidth(shuffleBitwidth) {}
238 LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
240 auto vecTy = dyn_cast<VectorType>(op.getType());
// Whole vector must fit in one shuffle; otherwise report a diagnostic
// (this case is instead handled by BreakDownSubgroupReduce).
244 unsigned vecBitwidth =
245 vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
246 if (vecBitwidth > shuffleBitwidth)
249 llvm::formatv(
"vector type bitwidth too large ({0}), cannot lower "
250 "to shuffles of size {1}",
251 vecBitwidth, shuffleBitwidth));
// The shuffle width must divide evenly into whole elements, or padding
// with extra elements cannot reach the shuffle width exactly.
253 unsigned elementsPerShuffle =
254 shuffleBitwidth / vecTy.getElementTypeBitWidth();
255 if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
257 op,
"shuffle bitwidth is not a multiple of the element bitwidth");
// Extended vector type that exactly fills one shuffle (its construction
// call is partially elided around this cast).
264 static_cast<int64_t
>(elementsPerShuffle), vecTy.getElementType());
265 Value extendedInput = op.getValue();
// Pad a narrower input into a zero vector of the extended type so the
// payload occupies the low elements.
266 if (vecBitwidth < shuffleBitwidth) {
267 auto zero = rewriter.
create<arith::ConstantOp>(
269 extendedInput = rewriter.
create<vector::InsertStridedSliceOp>(
270 loc, extendedInput, zero, 0, 1);
// pack: bitcast the extended vector to an integer vector and take its
// single scalar element for the shuffle ...
276 auto packFn = [loc, &rewriter, shuffleVecType](
Value unpackedVal) ->
Value {
278 rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
279 return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
// ... unpack: broadcast the shuffled scalar back into an integer vector
// and bitcast to the extended vector type.
281 auto unpackFn = [loc, &rewriter, shuffleVecType,
284 rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
285 return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
289 createSubgroupShuffleReduction(rewriter, loc, extendedInput, op.getOp(),
290 subgroupSize, packFn, unpackFn);
// Strip the padding back off so the result matches the original type.
292 if (vecBitwidth < shuffleBitwidth) {
293 res = rewriter.create<vector::ExtractStridedSliceOp>(
294 loc, res, 0, vecTy.getNumElements(),
298 rewriter.replaceOp(op, res);
// Configuration captured by the constructor.
303 unsigned subgroupSize = 0;
304 unsigned shuffleBitwidth = 0;
// Registration tails of the two public populate* entry points (their opening
// lines are outside this excerpt): the breakdown/scalarize patterns first,
// then the shuffle-lowering patterns with their size parameters.
312 maxShuffleBitwidth, benefit);
313 patterns.
add<ScalarizeSingleElementReduce>(patterns.
getContext(), benefit);
// Second entry point: register both scalar and vector shuffle lowerings.
319 patterns.
add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
320 patterns.
getContext(), subgroupSize, shuffleBitwidth, benefit);
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65535 (the most useful known pattern).
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provides a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next integer value.
vector::CombiningKind convertReductionKind(gpu::AllReduceOperation mode)
Returns the matching vector combining kind.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Include the generated interface declarations.
void populateGpuLowerSubgroupReduceToShufflePattenrs(RewritePatternSet &patterns, unsigned subgroupSize, unsigned shuffleBitwidth=32, PatternBenefit benefit=1)
Collect a set of patterns to lower gpu.subgroup_reduce into gpu.shuffle ops over shuffleBitwidth scalar types.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateGpuBreakDownSubgrupReducePatterns(RewritePatternSet &patterns, unsigned maxShuffleBitwidth=32, PatternBenefit benefit=1)
Collect a set of patterns to break down subgroup_reduce ops into smaller ones supported by the target...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...