struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;
  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
                       PatternRewriter &rewriter)
      : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
        loc(reduceOp.getLoc()), valueType(reduceOp.getValue().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}
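
  /// Lowers the gpu.all_reduce to a sequence of subgroup reductions: each
  /// subgroup reduces its elements via shuffles, the first lane of every
  /// subgroup stores its partial result to workgroup memory, and after a
  /// barrier the first subgroup reduces the partials and broadcasts the
  /// final value through the same buffer.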
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);
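
    // Compute linear invocation index and workgroup size.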
    Value dimX = getDimOp<gpu::BlockDimOp>(gpu::Dimension::x);
    Value dimY = getDimOp<gpu::BlockDimOp>(gpu::Dimension::y);
    Value dimZ = getDimOp<gpu::BlockDimOp>(gpu::Dimension::z);
    Value tidX = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::x);
    Value tidY = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::y);
    Value tidZ = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::z);
    Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);
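
    // Compute the lane id (invocation id within the subgroup).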
    Value subgroupMask =
        create<arith::ConstantIntOp>(int32Type, kSubgroupSize - 1);
    Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
    Value isFirstLane =
        create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
                              create<arith::ConstantIntOp>(int32Type, 0));

    Value numThreadsWithSmallerSubgroupId =
        create<arith::SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    Value activeWidth =
        create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");
    Value subgroupReduce = createSubgroupReduce(
        activeWidth, laneId, reduceOp.getValue(), accumFactory);
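
    // Add workgroup buffer to parent function for intermediate result.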
    Value buffer = createWorkgroupBuffer();
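
    // Write the intermediate results to workgroup memory, using the first
    // lane of each subgroup.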
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<arith::IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();
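
    // Compute the number of subgroups, rounding the workgroup size up.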
    Value biasedBlockSize =
        create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                  invocationIdx, numSubgroups);
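
    // Use the first numSubgroups invocations to reduce the intermediate
    // results from workgroup memory; the final value is stored at index zero.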
    Value zero = create<arith::ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });
    create<gpu::BarrierOp>();
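
    // Load and broadcast the result to all invocations.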
    Value result = create<memref::LoadOp>(valueType, buffer, zero);
    rewriter.replaceOp(reduceOp, result);
  }
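
  /// Shortcut to create an op at the stored location via the pattern
  /// rewriter.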
  template <typename T, typename... Args>
  T create(Args... args) {
    return T::create(rewriter, loc, std::forward<Args>(args)...);
  }
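
  /// Creates a dimension op of type T, with the result casted to int32.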
  template <typename T>
  Value getDimOp(gpu::Dimension dimension) {
    Value dim = create<T>(indexType, dimension);
    return create<arith::IndexCastOp>(int32Type, dim);
  }
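
  /// Adds a workgroup-memory buffer of kSubgroupSize elements to funcOp's
  /// workgroup attributions and returns it.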
  Value createWorkgroupBuffer() {
    auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get(
        funcOp->getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
    auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
                                      workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType, rewriter.getUnknownLoc());
  }
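
  /// Returns an accumulator factory based on the reduction body region if
  /// present, otherwise on the op attribute.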
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.getBody();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.getOp();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }
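
  /// Returns an accumulator factory that clones the reduction body. The
  /// body's entry block is expected to have two arguments, and gpu.yield
  /// returns the accumulated value.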
  AccumulatorFactory getFactory(Region &body) {
    return [&body, this](Value lhs, Value rhs) -> Value {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert the cloned accumulator body between the split blocks, mapping
      // its two entry-block arguments to the operands.
      IRMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add a branch before the inserted body, into the body.
      block = block->getNextNode();
      create<cf::BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with a branch out of the body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<cf::BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return the accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType(), lhs.getLoc());
    };
  }
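
  /// Returns an accumulator factory that creates the arith op matching
  /// opName.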
  AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
    return [opName, this](Value lhs, Value rhs) {
      return vector::makeArithReduction(rewriter, loc,
                                        convertReductionKind(opName), lhs, rhs);
    };
  }
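
  /// Creates an if-check based on condition, with then/else regions populated
  /// by the two factories. Values returned from the factories are passed as
  /// block arguments to the continuation block.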
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<cf::CondBranchOp>(condition, thenBlock,
                             /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                             /*falseOperands=*/ArrayRef<Value>());
    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<cf::BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<cf::BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType(), operand.getLoc());
  }
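
  /// Creates a block that is only executed when condition holds; the factory
  /// must not return any values.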
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }
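
  /// Creates a reduction across the first activeWidth lanes of a subgroup,
  /// falling through to a full-subgroup reduction when activeWidth is at
  /// least the subgroup size. The result is available in the first lane.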
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<arith::ConstantIntOp>(int32Type, kSubgroupSize);
    Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                    activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};

    createIf(
        isPartialSubgroup,
        // Generate a reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle the value from lane 'laneId ^ i' and
          // accumulate if the source lane is within the active range.
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(int32Type, i);
            auto shuffleOp = create<gpu::ShuffleOp>(
                shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
            // Skip the accumulation if the shuffle read from an inactive lane.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::ArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup, with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(int32Type, i);
            auto shuffleOp =
                create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
                                       gpu::ShuffleMode::XOR);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }
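
  /// Returns value divided by the subgroup size (i.e. 32).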
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<arith::ConstantIntOp>(int32Type, kSubgroupSize);
    return create<arith::DivSIOp>(int32Type, value, subgroupSize);
  }
  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  IndexType indexType;
  IntegerType int32Type;

  static constexpr int kSubgroupSize = 32;
};
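
/// Pattern matching gpu.func ops and lowering every uniform gpu.all_reduce
/// inside them with GpuAllReduceRewriter.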
struct GpuAllReduceRewrite : public RewritePattern {
  explicit GpuAllReduceRewrite(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);
    // Collect all uniform all_reduce ops; bail out on the first non-uniform
    // one, since only uniform reductions are supported.
    SmallVector<gpu::AllReduceOp> reduceOps;
    auto callback = [&](gpu::AllReduceOp reduceOp) -> WalkResult {
      if (!reduceOp.getUniform())
        return WalkResult::interrupt();
      reduceOps.emplace_back(reduceOp);
      return WalkResult::advance();
    };
    if (funcOp.walk(callback).wasInterrupted() || reduceOps.empty())
      return rewriter.notifyMatchFailure(
          op, "Non uniform reductions are not supported yet.");
    for (gpu::AllReduceOp reduceOp : reduceOps)
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();

    return success();
  }
};
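
// The lowering is exposed through populateGpuAllReducePatterns, which adds
// the pattern to a RewritePatternSet:
void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
  patterns.add<GpuAllReduceRewrite>(patterns.getContext());
}

// A minimal usage sketch (not part of this file), assuming it runs inside a
// pass over an op containing gpu.func ops, with the greedy pattern driver:
//
//   RewritePatternSet patterns(op->getContext());
//   populateGpuAllReducePatterns(patterns);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     signalPassFailure();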