24 #include "llvm/Support/ErrorHandling.h"
30 struct GpuAllReduceRewriter {
33 GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
35 : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
36 loc(reduceOp.getLoc()), valueType(reduceOp.getValue().
getType()),
68 rewriter.setInsertionPoint(reduceOp);
71 Value dimX = getDimOp<gpu::BlockDimOp>(gpu::Dimension::x);
72 Value dimY = getDimOp<gpu::BlockDimOp>(gpu::Dimension::y);
73 Value dimZ = getDimOp<gpu::BlockDimOp>(gpu::Dimension::z);
74 Value tidX = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::x);
75 Value tidY = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::y);
76 Value tidZ = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::z);
77 Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
78 Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
79 Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
80 Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
81 Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
82 Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);
86 create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
87 Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
89 create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
90 create<arith::ConstantIntOp>(0, int32Type));
92 Value numThreadsWithSmallerSubgroupId =
93 create<arith::SubIOp>(invocationIdx, laneId);
98 create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);
101 AccumulatorFactory accumFactory = getFactory();
102 assert(accumFactory &&
"failed to create accumulator factory");
105 Value subgroupReduce = createSubgroupReduce(
106 activeWidth, laneId, reduceOp.getValue(), accumFactory);
109 Value buffer = createWorkgroupBuffer();
113 createPredicatedBlock(isFirstLane, [&] {
114 Value subgroupId = getDivideBySubgroupSize(invocationIdx);
115 Value index = create<arith::IndexCastOp>(indexType, subgroupId);
116 create<memref::StoreOp>(subgroupReduce, buffer, index);
118 create<gpu::BarrierOp>();
121 Value biasedBlockSize =
122 create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
123 Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
124 Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
125 invocationIdx, numSubgroups);
130 Value zero = create<arith::ConstantIndexOp>(0);
131 createPredicatedBlock(isValidSubgroup, [&] {
132 Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
133 Value value = create<memref::LoadOp>(valueType, buffer, index);
135 createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
136 create<memref::StoreOp>(result, buffer, zero);
140 create<gpu::BarrierOp>();
141 Value result = create<memref::LoadOp>(valueType, buffer, zero);
143 rewriter.replaceOp(reduceOp, result);
148 template <
typename T,
typename... Args>
149 T create(Args... args) {
150 return rewriter.create<T>(loc, std::forward<Args>(args)...);
154 template <
typename T>
155 Value getDimOp(gpu::Dimension dimension) {
156 Value dim = create<T>(indexType, dimension);
157 return create<arith::IndexCastOp>(int32Type, dim);
161 Value createWorkgroupBuffer() {
164 funcOp->getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
166 workgroupMemoryAddressSpace);
167 return funcOp.addWorkgroupAttribution(bufferType, rewriter.getUnknownLoc());
172 AccumulatorFactory getFactory() {
173 auto &body = reduceOp.getBody();
175 return getFactory(body);
176 auto opAttr = reduceOp.getOp();
178 return getFactory(*opAttr);
179 return AccumulatorFactory();
185 AccumulatorFactory getFactory(
Region &body) {
187 Block *block = rewriter.getInsertionBlock();
188 Block *split = rewriter.
splitBlock(block, rewriter.getInsertionPoint());
194 rewriter.cloneRegionBefore(body, *split->
getParent(),
195 split->getIterator(), mapping);
198 block = block->getNextNode();
202 for (; block != split; block = block->getNextNode()) {
204 if (!isa<gpu::YieldOp>(terminator))
206 rewriter.setInsertionPointToEnd(block);
207 rewriter.replaceOpWithNewOp<cf::BranchOp>(
208 terminator, split,
ValueRange(terminator->getOperand(0)));
212 rewriter.setInsertionPointToStart(split);
213 return split->
addArgument(lhs.getType(), lhs.getLoc());
218 AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
237 template <
typename ThenOpsFactory,
typename ElseOpsFactory>
238 void createIf(
Value condition, ThenOpsFactory &&thenOpsFactory,
239 ElseOpsFactory &&elseOpsFactory) {
240 Block *currentBlock = rewriter.getInsertionBlock();
241 auto currentPoint = rewriter.getInsertionPoint();
247 rewriter.setInsertionPointToEnd(currentBlock);
248 create<cf::CondBranchOp>(condition, thenBlock,
252 rewriter.setInsertionPointToStart(thenBlock);
253 auto thenOperands = thenOpsFactory();
254 create<cf::BranchOp>(continueBlock, thenOperands);
256 rewriter.setInsertionPointToStart(elseBlock);
257 auto elseOperands = elseOpsFactory();
258 create<cf::BranchOp>(continueBlock, elseOperands);
260 assert(thenOperands.size() == elseOperands.size());
261 rewriter.setInsertionPointToStart(continueBlock);
262 for (
auto operand : thenOperands)
263 continueBlock->
addArgument(operand.getType(), operand.getLoc());
267 template <
typename Factory>
268 void createPredicatedBlock(
Value condition, Factory &&predicatedOpsFactory) {
269 static_assert(std::is_same<decltype(predicatedOpsFactory()),
void>::value,
270 "predicatedOpsFactory should not return any value");
274 predicatedOpsFactory();
284 AccumulatorFactory &accumFactory) {
285 Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
286 Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
287 activeWidth, subgroupSize);
288 std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
294 Value value = operand;
298 for (
int i = 1; i < kSubgroupSize; i <<= 1) {
299 Value offset = create<arith::ConstantIntOp>(i, int32Type);
300 auto shuffleOp = create<gpu::ShuffleOp>(
301 shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
305 shuffleOp.getResult(1),
307 return SmallVector<Value, 1>{
308 accumFactory(value, shuffleOp.getResult(0))};
311 value = rewriter.getInsertionBlock()->getArgument(0);
319 Value value = operand;
320 for (
int i = 1; i < kSubgroupSize; i <<= 1) {
321 Value offset = create<arith::ConstantIntOp>(i, int32Type);
323 create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
324 gpu::ShuffleMode::XOR);
325 value = accumFactory(value, shuffleOp.getResult(0));
329 return rewriter.getInsertionBlock()->getArgument(0);
334 Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
335 return create<arith::DivSIOp>(int32Type, value, subgroupSize);
338 gpu::GPUFuncOp funcOp;
339 gpu::AllReduceOp reduceOp;
345 IntegerType int32Type;
347 static constexpr
int kSubgroupSize = 32;
352 :
RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}
354 LogicalResult matchAndRewrite(
Operation *op,
356 auto funcOp = cast<gpu::GPUFuncOp>(op);
359 auto callback = [&](gpu::AllReduceOp reduceOp) ->
WalkResult {
360 if (!reduceOp.getUniform())
363 reduceOps.emplace_back(reduceOp);
367 if (funcOp.walk(callback).wasInterrupted() || reduceOps.empty())
369 op,
"Non uniform reductions are not supported yet.");
371 for (gpu::AllReduceOp reduceOp : reduceOps)
372 GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
static MLIRContext * getContext(OpFoldResult val)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Block represents an ordered list of Operations.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgument getArgument(unsigned i)
RewritePattern is the common base class for all DAG to DAG replacements.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
vector::CombiningKind convertReductionKind(gpu::AllReduceOperation mode)
Returns the matching vector combining kind.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet & patterns
void populateGpuAllReducePatterns(RewritePatternSet &patterns)
Collect a set of patterns to rewrite all-reduce ops within the GPU dialect.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...