//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Transforms/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

namespace {

struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
                       PatternRewriter &rewriter)
      : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
        loc(reduceOp.getLoc()), valueType(reduceOp.getValue().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cf.cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     cf.br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
  ///     cf.cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     cf.br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
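  /// For example, this rewrites a uniform reduction such as the following
  /// (illustrative input; any combiner accepted by `getFactory` works):
  ///
  ///     %sum = gpu.all_reduce add %operand uniform {} : (f32) -> (f32)
  ///
  /// into the block structure above, in place within the gpu.func.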
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
    Value dimX = getDimOp<gpu::BlockDimOp>(gpu::Dimension::x);
    Value dimY = getDimOp<gpu::BlockDimOp>(gpu::Dimension::y);
    Value dimZ = getDimOp<gpu::BlockDimOp>(gpu::Dimension::z);
    Value tidX = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::x);
    Value tidY = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::y);
    Value tidZ = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::z);
    Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);

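    // Worked example (hypothetical sizes): in an (8, 4, 2) workgroup, the
    // thread (tidX, tidY, tidZ) = (3, 1, 1) gets
    // invocationIdx = ((1 * 4) + 1) * 8 + 3 = 43 and workgroupSize = 64.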
    // Compute lane id (invocation id within the subgroup).
    Value subgroupMask =
        create<arith::ConstantIntOp>(int32Type, kSubgroupSize - 1);
    Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
    Value isFirstLane =
        create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
                              create<arith::ConstantIntOp>(int32Type, 0));

    Value numThreadsWithSmallerSubgroupId =
        create<arith::SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

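    // E.g. (continuing the hypothetical numbers above): invocationIdx = 43
    // gives laneId = 43 & 31 = 11 and activeWidth = 64 - 32 = 32; in a
    // 70-thread workgroup, the last subgroup would get activeWidth = 70 - 64
    // = 6.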
    // Create factory for the op which accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate
    // results.
    Value subgroupReduce = createSubgroupReduce(
        activeWidth, laneId, reduceOp.getValue(), accumFactory);

    // Add workgroup buffer to parent function for intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<arith::IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute number of active subgroups.
    Value biasedBlockSize =
        create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                  invocationIdx, numSubgroups);
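    // E.g. (hypothetical size): workgroupSize = 70 yields numSubgroups =
    // (70 + 31) / 32 = 3, i.e. a ceiling division by the subgroup size.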

    // Use the first numSubgroups invocations to reduce the intermediate
    // results from workgroup memory. The final result is written to workgroup
    // memory again.
    Value zero = create<arith::ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<memref::LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  template <typename T, typename... Args>
  T create(Args... args) {
    return T::create(rewriter, loc, std::forward<Args>(args)...);
  }

  // Creates a dimension op of type T, with the result cast to int32.
  template <typename T>
  Value getDimOp(gpu::Dimension dimension) {
    Value dim = create<T>(indexType, dimension);
    return create<arith::IndexCastOp>(int32Type, dim);
  }

  /// Adds type to funcOp's workgroup attributions.
  Value createWorkgroupBuffer() {
    // TODO: Pick a proper location for the attribution.
    auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get(
        funcOp->getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
    auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
                                      workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType, rewriter.getUnknownLoc());
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.getBody();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.getOp();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield op returns the
  /// accumulated value of the same type.
  AccumulatorFactory getFactory(Region &body) {
    return [&body, this](Value lhs, Value rhs) -> Value {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert accumulator body between the split blocks.
      IRMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add branch before inserted body, into body.
      block = block->getNextNode();
      create<cf::BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with branch out of body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<cf::BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType(), lhs.getLoc());
    };
  }

  /// Returns an accumulator factory that creates an op specified by opName.
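  /// E.g. an `add` reduction over f32 operands materializes as arith.addf,
  /// via the vector::makeArithReduction helper below.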
  AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
    return [opName, this](Value lhs, Value rhs) {
      return vector::makeArithReduction(
          rewriter, loc, convertReductionKind(opName), lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` blocks.
  ///
  ///     cf.cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     cf.br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     cf.br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
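  /// For example (hypothetical usage mirroring createSubgroupReduce below),
  /// selecting between two values materializes as a ^continue block argument:
  ///
  ///   createIf(cond, [&] { return SmallVector<Value, 1>{a}; },
  ///                  [&] { return SmallVector<Value, 1>{b}; });
  ///   Value result = rewriter.getInsertionBlock()->getArgument(0);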
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<cf::CondBranchOp>(condition, thenBlock,
                             /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                             /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<cf::BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<cf::BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType(), operand.getLoc());
  }

  /// Shortcut for createIf with empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
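  /// With kSubgroupSize = 32, the XOR butterfly walks offsets 1, 2, 4, 8, 16,
  /// so after log2(32) = 5 shuffle-and-accumulate steps the first lane has
  /// combined the values of all active lanes.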
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<arith::ConstantIntOp>(int32Type, kSubgroupSize);
    Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                    activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if
          // source lane is within the active range. The accumulated value is
          // available in the first lane.
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(int32Type, i);
            auto shuffleOp = create<gpu::ShuffleOp>(
                shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::ArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(int32Type, i);
            auto shuffleOp =
                create<gpu::ShuffleOp>(shuffleType, value, offset,
                                       subgroupSize, gpu::ShuffleMode::XOR);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<arith::ConstantIntOp>(int32Type, kSubgroupSize);
    return create<arith::DivSIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  IntegerType int32Type;

  static constexpr int kSubgroupSize = 32;
};

struct GpuAllReduceRewrite : public RewritePattern {
  explicit GpuAllReduceRewrite(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);

    SmallVector<gpu::AllReduceOp> reduceOps;
    auto callback = [&](gpu::AllReduceOp reduceOp) -> WalkResult {
      if (!reduceOp.getUniform())
        return WalkResult::interrupt();

      reduceOps.emplace_back(reduceOp);
      return WalkResult::advance();
    };

    if (funcOp.walk(callback).wasInterrupted() || reduceOps.empty())
      return rewriter.notifyMatchFailure(
          op, "Non uniform reductions are not supported yet.");

    for (gpu::AllReduceOp reduceOp : reduceOps)
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();

    return success();
  }
};
} // namespace

void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
  patterns.add<GpuAllReduceRewrite>(patterns.getContext());
}
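
// A minimal sketch of how a pass might drive these patterns (hypothetical
// pass body; applyPatternsGreedily is MLIR's greedy rewrite driver):
//
//   void MyGpuAllReducePass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     populateGpuAllReducePatterns(patterns);
//     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//       signalPassFailure();
//   }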