AllReduceLowering.cpp
//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

20#include "mlir/IR/Builders.h"
21#include "mlir/IR/IRMapping.h"
23
using namespace mlir;

namespace {

struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
                       PatternRewriter &rewriter)
      : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
        loc(reduceOp.getLoc()), valueType(reduceOp.getValue().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory.
  ///
  /// %subgroup_reduce = `createSubgroupReduce(%operand)`
  /// cf.cond_br %is_first_lane, ^then1, ^continue1
  /// ^then1:
  /// store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  /// cf.br ^continue1
  /// ^continue1:
  /// gpu.barrier
  /// %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
  /// cf.cond_br %is_valid_subgroup, ^then2, ^continue2
  /// ^then2:
  /// %partial_reduce = load %workgroup_buffer[%invocation_idx]
  /// %all_reduce = `createSubgroupReduce(%partial_reduce)`
  /// store %all_reduce, %workgroup_buffer[%zero]
  /// cf.br ^continue2
  /// ^continue2:
  /// gpu.barrier
  /// %result = load %workgroup_buffer[%zero]
  /// return %result
  ///
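  /// As an illustrative example (operand name and element type are
  /// hypothetical), an op along the lines of
  ///
  ///   %sum = gpu.all_reduce add %operand uniform {} : (f32) -> (f32)
  ///
  /// would be expanded into the block structure sketched above, with `add`
  /// supplying the accumulator used by `createSubgroupReduce`.
  ///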
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
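    // Spelled out, the computation below is (in int32 arithmetic):
    //   invocationIdx = (tidZ * dimY + tidY) * dimX + tidX
    //   workgroupSize = dimX * dimY * dimZ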
    Value dimX = getDimOp<gpu::BlockDimOp>(gpu::Dimension::x);
    Value dimY = getDimOp<gpu::BlockDimOp>(gpu::Dimension::y);
    Value dimZ = getDimOp<gpu::BlockDimOp>(gpu::Dimension::z);
    Value tidX = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::x);
    Value tidY = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::y);
    Value tidZ = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::z);
    Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);

    // Compute lane id (invocation id within the subgroup).
    Value subgroupMask =
        create<arith::ConstantIntOp>(int32Type, kSubgroupSize - 1);
    Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
    Value isFirstLane =
        create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
                              create<arith::ConstantIntOp>(int32Type, 0));

    Value numThreadsWithSmallerSubgroupId =
        create<arith::SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

    // Create factory for the op that accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate results.
    Value subgroupReduce = createSubgroupReduce(
        activeWidth, laneId, reduceOp.getValue(), accumFactory);

    // Add workgroup buffer to parent function for intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<arith::IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>(buffer);

    // Compute number of active subgroups.
    Value biasedBlockSize =
        create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                  invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate results
    // from workgroup memory. The final result is written to workgroup memory
    // again.
    Value zero = create<arith::ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>(buffer);
    Value result = create<memref::LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
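  // For example, create<arith::AddIOp>(int32Type, lhs, rhs) forwards to
  // arith::AddIOp::create(rewriter, loc, int32Type, lhs, rhs).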
  template <typename T, typename... Args>
  T create(Args... args) {
    return T::create(rewriter, loc, std::forward<Args>(args)...);
  }

  // Creates dimension op of type T, with the result cast to int32.
  template <typename T>
  Value getDimOp(gpu::Dimension dimension) {
    Value dim = create<T>(indexType, dimension);
    return create<arith::IndexCastOp>(int32Type, dim);
  }

  /// Adds a workgroup memory attribution of the buffer type to funcOp and
  /// returns it.
  Value createWorkgroupBuffer() {
    // TODO: Pick a proper location for the attribution.
    auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get(
        funcOp->getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
    auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
                                      workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType, rewriter.getUnknownLoc());
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.getBody();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.getOp();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield op returns the
  /// accumulated value of the same type.
  AccumulatorFactory getFactory(Region &body) {
    return [&body, this](Value lhs, Value rhs) -> Value {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert the accumulator body between the split blocks.
      IRMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add branch before inserted body, into body.
      block = block->getNextNode();
      create<cf::BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with branch out of body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<cf::BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType(), lhs.getLoc());
    };
  }

  /// Returns an accumulator factory that creates an op specified by opName.
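  /// For instance, an `add` reduction over an f32 value maps to an
  /// `arith.addf` of the two operands via vector::makeArithReduction.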
  AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
    return [opName, this](Value lhs, Value rhs) {
      return vector::makeArithReduction(rewriter, loc,
                                        convertReductionKind(opName), lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` blocks.
  ///
  /// cf.cond_br %condition, ^then, ^else
  /// ^then:
  /// %then_operands = `thenOpsFactory()`
  /// cf.br ^continue(%then_operands)
  /// ^else:
  /// %else_operands = `elseOpsFactory()`
  /// cf.br ^continue(%else_operands)
  /// ^continue(%block_operands):
  ///
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<cf::CondBranchOp>(condition, thenBlock,
                             /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                             /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<cf::BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<cf::BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType(), operand.getLoc());
  }

  /// Shortcut for createIf with empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; all other lanes' return values are
  /// undefined.
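  ///
  /// The reduction is a butterfly: for offsets i = 1, 2, 4, ..., each lane
  /// shuffles the value held by lane (laneId ^ i) and folds it into its
  /// accumulator. Over a full subgroup every lane ends up with the complete
  /// reduction; in the partial case, lanes whose shuffle source lies outside
  /// the active range skip the accumulation, so only the first lane is
  /// guaranteed to hold the final value.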
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<arith::ConstantIntOp>(int32Type, kSubgroupSize);
    Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                    activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
          // lane is within the active range. The accumulated value is available
          // in the first lane.
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(int32Type, i);
            auto shuffleOp = create<gpu::ShuffleOp>(
                shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::ArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(int32Type, i);
            auto shuffleOp =
                create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
                                       gpu::ShuffleMode::XOR);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<arith::ConstantIntOp>(int32Type, kSubgroupSize);
    return create<arith::DivSIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  IntegerType int32Type;

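  // Note: the lowering is hard-wired to a subgroup (warp/wavefront) width of
  // 32 lanes; targets with a different subgroup size (e.g. 64-wide wavefronts)
  // are not modeled here.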
  static constexpr int kSubgroupSize = 32;
};

struct GpuAllReduceRewrite : public RewritePattern {
  explicit GpuAllReduceRewrite(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);

    SmallVector<gpu::AllReduceOp> reduceOps;
    auto callback = [&](gpu::AllReduceOp reduceOp) -> WalkResult {
      if (!reduceOp.getUniform())
        return WalkResult::interrupt();

      reduceOps.emplace_back(reduceOp);
      return WalkResult::advance();
    };

    if (funcOp.walk(callback).wasInterrupted() || reduceOps.empty())
      return rewriter.notifyMatchFailure(
          op, "Non-uniform reductions are not supported yet.");

    for (gpu::AllReduceOp reduceOp : reduceOps)
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();

    return success();
  }
};
} // namespace

void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
  patterns.add<GpuAllReduceRewrite>(patterns.getContext());
}
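
// A minimal usage sketch (hypothetical pass body, not part of this file): the
// collected patterns would typically be applied with a rewrite driver, e.g.:
//
//   void runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     populateGpuAllReducePatterns(patterns);
//     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//       signalPassFailure();
//   }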