struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
                       PatternRewriter &rewriter)
      : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
        loc(reduceOp.getLoc()), valueType(reduceOp.getValue().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}
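
  /// Lowers the all_reduce to a sequence of subgroup reductions: each subgroup
  /// reduces its lanes with XOR shuffles, the first lane of every subgroup
  /// writes the partial result to a workgroup-memory buffer, and after a
  /// barrier the first invocations reduce those partial results and broadcast
  /// the final value to all invocations through the same buffer.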
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);
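
    // Compute the linear invocation index and the workgroup size.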
    Value dimX = getDimOp<gpu::BlockDimOp>(gpu::Dimension::x);
    Value dimY = getDimOp<gpu::BlockDimOp>(gpu::Dimension::y);
    Value dimZ = getDimOp<gpu::BlockDimOp>(gpu::Dimension::z);
    Value tidX = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::x);
    Value tidY = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::y);
    Value tidZ = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::z);
    Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);
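
    // Compute the lane id (invocation id within the subgroup).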
    Value subgroupMask =
        create<arith::ConstantIntOp>(int32Type, kSubgroupSize - 1);
    Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
    Value isFirstLane =
        create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
                              create<arith::ConstantIntOp>(int32Type, 0));
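
    // Invocations with a smaller id than the current subgroup's first lane,
    // and the number of active invocations from this subgroup onwards.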
    Value numThreadsWithSmallerSubgroupId =
        create<arith::SubIOp>(invocationIdx, laneId);
    Value activeWidth =
        create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);
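
    // Create an accumulator factory from the reduction op attribute or region.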
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate result.
    Value subgroupReduce = createSubgroupReduce(
        activeWidth, laneId, reduceOp.getValue(), accumFactory);
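
    // Add a workgroup buffer to the parent function for the intermediate
    // results.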
    Value buffer = createWorkgroupBuffer();
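
    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.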
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<arith::IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();
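
    // Compute the number of subgroups, rounding the workgroup size up.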
    Value biasedBlockSize =
        create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                  invocationIdx, numSubgroups);
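
    // Use the first numSubgroups invocations to reduce the intermediate results
    // from workgroup memory. The final result is written back to workgroup
    // memory.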
    Value zero = create<arith::ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });
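
    // Synchronize the workgroup and load the result from workgroup memory.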
    create<gpu::BarrierOp>();
    Value result = create<memref::LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
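  // Shortcut for T::create with the pattern rewriter and the reduce op's
  // location.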
  template <typename T, typename... Args>
  T create(Args... args) {
    return T::create(rewriter, loc, std::forward<Args>(args)...);
  }
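
  // Creates a dimension op of type T, with the result cast to int32.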
  template <typename T>
  Value getDimOp(gpu::Dimension dimension) {
    Value dim = create<T>(indexType, dimension);
    return create<arith::IndexCastOp>(int32Type, dim);
  }
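
  // Adds a workgroup-memory buffer of kSubgroupSize elements to funcOp's
  // workgroup attributions and returns it.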
  Value createWorkgroupBuffer() {
    auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get(
        funcOp->getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
    auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
                                      workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType, rewriter.getUnknownLoc());
  }
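
  // Returns an accumulator factory that uses the reduction region if present,
  // the op attribute otherwise, or a null factory if neither is available.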
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.getBody();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.getOp();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }
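
  // Returns an accumulator factory that clones the reduction region between
  // the current block and a split continuation block, rewiring gpu.yield ops
  // to branch into the continuation with the accumulated value.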
  AccumulatorFactory getFactory(Region &body) {
    return [&body, this](Value lhs, Value rhs) -> Value {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      IRMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      block = block->getNextNode();
      create<cf::BranchOp>(block, ValueRange());

      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<cf::BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType(), lhs.getLoc());
    };
  }
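
  // Returns an accumulator factory that materializes the reduction specified
  // by the op attribute as the matching arithmetic combining operation.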
  AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
    return [opName, this](Value lhs, Value rhs) {
      return vector::makeArithReduction(rewriter, loc,
                                        convertReductionKind(opName), lhs, rhs);
    };
  }
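
  // Creates an if/else CFG diamond: the condition branches to a then block and
  // an else block, both of which branch to a common continuation block whose
  // block arguments carry the operands returned by the two factories.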
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<cf::CondBranchOp>(condition, thenBlock,
                             ArrayRef<Value>(), elseBlock,
                             ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<cf::BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<cf::BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType(), operand.getLoc());
  }
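
  // Shortcut for createIf with an empty else block and no block arguments; the
  // predicated ops factory must not return a value.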
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }
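
  // Creates a butterfly reduction across the first activeWidth lanes of a
  // subgroup (or the whole subgroup when activeWidth covers it), using XOR
  // shuffles; the result ends up in the first lane.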
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<arith::ConstantIntOp>(int32Type, kSubgroupSize);
    Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                    activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
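
    // Reduce over a (potentially) partial subgroup: accumulate only when the
    // shuffle read from an active lane.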
    createIf(
        isPartialSubgroup,
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(int32Type, i);
            auto shuffleOp = create<gpu::ShuffleOp>(
                shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::ArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
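        // Reduce over the entire subgroup: accumulate unconditionally.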
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(int32Type, i);
            auto shuffleOp =
                create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
                                       gpu::ShuffleMode::XOR);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }
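
  // Returns value divided by the subgroup size (i.e. 32).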
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<arith::ConstantIntOp>(int32Type, kSubgroupSize);
    return create<arith::DivSIOp>(int32Type, value, subgroupSize);
  }
  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  IndexType indexType;
  IntegerType int32Type;

  static constexpr int kSubgroupSize = 32;
};
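
// Rewrites all uniform gpu.all_reduce ops inside a gpu.func; non-uniform
// reductions are rejected.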
struct GpuAllReduceRewrite : public RewritePattern {
  explicit GpuAllReduceRewrite(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);
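
    // Collect the uniform all_reduce ops nested in the function; a non-uniform
    // reduction aborts the walk.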
    SmallVector<gpu::AllReduceOp> reduceOps;
    auto callback = [&](gpu::AllReduceOp reduceOp) -> WalkResult {
      if (!reduceOp.getUniform())
        return WalkResult::interrupt();
      reduceOps.emplace_back(reduceOp);
      return WalkResult::advance();
    };
    if (funcOp.walk(callback).wasInterrupted() || reduceOps.empty())
      return rewriter.notifyMatchFailure(
          op, "Non uniform reductions are not supported yet.");
    for (gpu::AllReduceOp reduceOp : reduceOps)
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();

    return success();
  }
};
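
// Minimal usage sketch (illustrative assumption, not part of this file): the
// pattern is normally collected via populateGpuAllReducePatterns and applied
// from a pass through the greedy pattern driver, e.g.
//
//   RewritePatternSet patterns(ctx);
//   populateGpuAllReducePatterns(patterns);
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));
//
// The exact driver entry point may differ across MLIR versions.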