//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Transforms/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

namespace {

struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
                       PatternRewriter &rewriter)
      : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
        loc(reduceOp.getLoc()), valueType(reduceOp.getValue().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cf.cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     cf.br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
  ///     cf.cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     cf.br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
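  /// For reference, a typical source op this rewriter replaces looks like the
  /// following (illustrative syntax, assuming an f32 add reduction):
  ///
  ///     %sum = gpu.all_reduce add %operand uniform {} : (f32) -> (f32)
  ///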
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
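    // Following the usual row-major linearization:
    //   invocationIdx = (tidZ * dimY + tidY) * dimX + tidX
    //   workgroupSize = dimX * dimY * dimZ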
    Value dimX = getDimOp<gpu::BlockDimOp>(gpu::Dimension::x);
    Value dimY = getDimOp<gpu::BlockDimOp>(gpu::Dimension::y);
    Value dimZ = getDimOp<gpu::BlockDimOp>(gpu::Dimension::z);
    Value tidX = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::x);
    Value tidY = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::y);
    Value tidZ = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::z);
    Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);

    // Compute lane id (invocation id within the subgroup).
    Value subgroupMask =
        create<arith::ConstantIntOp>(int32Type, kSubgroupSize - 1);
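    // kSubgroupSize is a power of two, so the bitwise AND with
    // kSubgroupSize - 1 is equivalent to invocationIdx % kSubgroupSize.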
    Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
    Value isFirstLane =
        create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
                              create<arith::ConstantIntOp>(int32Type, 0));

    Value numThreadsWithSmallerSubgroupId =
        create<arith::SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

    // Create a factory for the op that accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate results.
    Value subgroupReduce = createSubgroupReduce(
        activeWidth, laneId, reduceOp.getValue(), accumFactory);

    // Add workgroup buffer to parent function for intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<arith::IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute number of active subgroups.
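    // numSubgroups = ceil(workgroupSize / kSubgroupSize): adding
    // kSubgroupSize - 1 (subgroupMask) before the signed division rounds up.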
    Value biasedBlockSize =
        create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                  invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate results
    // from workgroup memory. The final result is written to workgroup memory
    // again.
    Value zero = create<arith::ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<memref::LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  template <typename T, typename... Args>
  T create(Args... args) {
    return T::create(rewriter, loc, std::forward<Args>(args)...);
  }

  // Creates dimension op of type T, with the result cast to int32.
  template <typename T>
  Value getDimOp(gpu::Dimension dimension) {
    Value dim = create<T>(indexType, dimension);
    return create<arith::IndexCastOp>(int32Type, dim);
  }

  /// Adds type to funcOp's workgroup attributions.
  Value createWorkgroupBuffer() {
    // TODO: Pick a proper location for the attribution.
    auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get(
        funcOp->getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
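    // One buffer element per subgroup. Note that this effectively caps the
    // supported workgroup size at kSubgroupSize * kSubgroupSize invocations.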
    auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
                                      workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType, rewriter.getUnknownLoc());
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.getBody();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.getOp();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield op returns the
  /// accumulated value of the same type.
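  ///
  /// An example body region for a max reduction (illustrative; any body that
  /// takes two arguments and terminates in a gpu.yield of the same type fits):
  ///
  ///     ^bb0(%lhs : f32, %rhs : f32):
  ///       %max = arith.maximumf %lhs, %rhs : f32
  ///       gpu.yield %max : f32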
  AccumulatorFactory getFactory(Region &body) {
    return [&body, this](Value lhs, Value rhs) -> Value {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert accumulator body between the split blocks.
      IRMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add branch before inserted body, into body.
      block = block->getNextNode();
      create<cf::BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with branch out of body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<cf::BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType(), lhs.getLoc());
    };
  }

  /// Returns an accumulator factory that creates an op specified by opName.
  AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
    return [opName, this](Value lhs, Value rhs) {
      return vector::makeArithReduction(rewriter, loc,
                                        convertReductionKind(opName), lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` block.
  ///
  ///     cf.cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     cf.br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     cf.br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
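  /// A sketch of how this is used below (both factories must return the same
  /// number of values, which become the continuation block's arguments):
  ///
  ///     createIf(
  ///         condition, [&] { return SmallVector<Value, 1>{a}; },
  ///         [&] { return SmallVector<Value, 1>{b}; });
  ///     Value merged = rewriter.getInsertionBlock()->getArgument(0);
  ///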
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<cf::CondBranchOp>(condition, thenBlock,
                             /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                             /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<cf::BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<cf::BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType(), operand.getLoc());
  }

  /// Shortcut for createIf with empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is at least the subgroup width. The
  /// first lane returns the result; the return values of all other lanes are
  /// undefined.
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<arith::ConstantIntOp>(int32Type, kSubgroupSize);
    Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                    activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
          // lane is within the active range. The accumulated value is available
          // in the first lane.
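          // With kSubgroupSize = 32 the xor offsets are 1, 2, 4, 8, and 16;
          // after the last step lane 0 holds the reduction of all active lanes.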
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(int32Type, i);
            auto shuffleOp = create<gpu::ShuffleOp>(
                shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::ArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(int32Type, i);
            auto shuffleOp =
                create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
                                       gpu::ShuffleMode::XOR);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<arith::ConstantIntOp>(int32Type, kSubgroupSize);
    return create<arith::DivSIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  IntegerType int32Type;

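  /// Hard-coded subgroup width. This matches the 32-wide warps of NVIDIA GPUs;
  /// targets with a different subgroup width (e.g. 64-wide AMD wavefronts) are
  /// not modeled by this lowering.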
  static constexpr int kSubgroupSize = 32;
};

struct GpuAllReduceRewrite : public RewritePattern {
  explicit GpuAllReduceRewrite(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);

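    // Collect the candidate ops up front and only rewrite afterwards, so the
    // rewrites do not disturb the walk. A single non-uniform all-reduce makes
    // the whole function unmatchable.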
    SmallVector<gpu::AllReduceOp> reduceOps;
    auto callback = [&](gpu::AllReduceOp reduceOp) -> WalkResult {
      if (!reduceOp.getUniform())
        return WalkResult::interrupt();

      reduceOps.emplace_back(reduceOp);
      return WalkResult::advance();
    };

    if (funcOp.walk(callback).wasInterrupted() || reduceOps.empty())
      return rewriter.notifyMatchFailure(
          op, "Non-uniform reductions are not supported yet.");

    for (gpu::AllReduceOp reduceOp : reduceOps)
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();

    return success();
  }
};
} // namespace

void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
  patterns.add<GpuAllReduceRewrite>(patterns.getContext());
}
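
// A minimal sketch of how these patterns might be applied (illustrative; the
// names `module` and `ctx` are placeholders, and in-tree users typically wire
// this up inside a pass):
//
//   RewritePatternSet patterns(ctx);
//   populateGpuAllReducePatterns(patterns);
//   if (failed(applyPatternsGreedily(module, std::move(patterns))))
//     signalPassFailure();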