//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Transforms/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;

namespace {

struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
                       PatternRewriter &rewriter)
      : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
        loc(reduceOp.getLoc()), valueType(reduceOp.getValue().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cf.cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     cf.br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
  ///     cf.cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     cf.br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
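  /// For illustration, a uniform all-reduce such as
  ///
  ///     %sum = gpu.all_reduce add %operand uniform {} : (f32) -> (f32)
  ///
  /// inside a gpu.func is expanded in place into the block structure sketched
  /// above.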
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
    Value dimX = getDimOp<gpu::BlockDimOp>(gpu::Dimension::x);
    Value dimY = getDimOp<gpu::BlockDimOp>(gpu::Dimension::y);
    Value dimZ = getDimOp<gpu::BlockDimOp>(gpu::Dimension::z);
    Value tidX = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::x);
    Value tidY = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::y);
    Value tidZ = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::z);
    Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);
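    // In closed form: invocationIdx = (tidZ * dimY + tidY) * dimX + tidX and
    // workgroupSize = dimX * dimY * dimZ.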

    // Compute lane id (invocation id within the subgroup).
    Value subgroupMask =
        create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
    Value isFirstLane =
        create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
                              create<arith::ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<arith::SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);
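    // E.g. for workgroupSize = 100 with 32-wide subgroups, invocation 97 has
    // laneId = 1 and activeWidth = 100 - 96 = 4 in the trailing, partial
    // subgroup; invocations in full subgroups see activeWidth >= 32.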

    // Create the factory for the op that accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate
    // results.
    Value subgroupReduce = createSubgroupReduce(
        activeWidth, laneId, reduceOp.getValue(), accumFactory);

    // Add workgroup buffer to parent function for the intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<arith::IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute the number of active subgroups.
    Value biasedBlockSize =
        create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
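    // Biasing by subgroupMask turns the division into a ceiling division, so
    // a trailing partial subgroup is counted as well.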
    Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                  invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate
    // results from workgroup memory. The final result is written to workgroup
    // memory again.
    Value zero = create<arith::ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<memref::LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from the rewriter using loc as the first
  // argument.
  template <typename T, typename... Args>
  T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates a dimension op of type T, with the result cast to int32.
  template <typename T>
  Value getDimOp(gpu::Dimension dimension) {
    Value dim = create<T>(indexType, dimension);
    return create<arith::IndexCastOp>(int32Type, dim);
  }

  /// Adds a workgroup buffer to funcOp's workgroup attributions and returns
  /// it.
  Value createWorkgroupBuffer() {
    // TODO: Pick a proper location for the attribution.
    auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get(
        funcOp->getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
    auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
                                      workgroupMemoryAddressSpace);
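    // kSubgroupSize entries (one per subgroup) suffice, assuming the usual
    // limit of kSubgroupSize * kSubgroupSize = 1024 invocations per workgroup.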
    return funcOp.addWorkgroupAttribution(bufferType, rewriter.getUnknownLoc());
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.getBody();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.getOp();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield op returns the
  /// accumulated value of the same type.
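  ///
  /// For example, a reduction with a custom body such as
  ///
  ///     %res = gpu.all_reduce %operand uniform {
  ///     ^bb(%lhs : f32, %rhs : f32):
  ///       %mul = arith.mulf %lhs, %rhs : f32
  ///       "gpu.yield"(%mul) : (f32) -> ()
  ///     } : (f32) -> (f32)
  ///
  /// has its body cloned once per accumulation step, with the entry block
  /// arguments mapped to the running value and the shuffled value.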
  AccumulatorFactory getFactory(Region &body) {
    return [&body, this](Value lhs, Value rhs) -> Value {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert the accumulator body between the split blocks.
      IRMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add a branch before the inserted body that jumps into it.
      block = block->getNextNode();
      create<cf::BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with a branch out of the body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<cf::BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return the accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType(), lhs.getLoc());
    };
  }

  /// Returns an accumulator factory that creates an op specified by opName.
  AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
    return [opName, this](Value lhs, Value rhs) {
      return vector::makeArithReduction(rewriter, loc,
                                        convertReductionKind(opName), lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` blocks.
  ///
  ///     cf.cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     cf.br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     cf.br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<cf::CondBranchOp>(condition, thenBlock,
                             /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                             /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<cf::BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<cf::BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType(), operand.getLoc());
  }

  /// Shortcut for createIf with empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
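  ///
  /// The reduction forms a butterfly over gpu.shuffle xor: with offsets 1, 2,
  /// 4, ..., lane 0 first accumulates lane 1, then lane 2 (which by then holds
  /// the sum of lanes 2 and 3), and so on until it holds the reduction of all
  /// active lanes.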
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                    activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if the
          // source lane is within the active range. The accumulated value is
          // available in the first lane.
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(
                shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::ArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp =
                create<gpu::ShuffleOp>(shuffleType, value, offset,
                                       subgroupSize, gpu::ShuffleMode::XOR);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    return create<arith::DivSIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  IntegerType int32Type;

  static constexpr int kSubgroupSize = 32;
};

struct GpuAllReduceRewrite : public RewritePattern {
  explicit GpuAllReduceRewrite(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);

    SmallVector<gpu::AllReduceOp> reduceOps;
    auto callback = [&](gpu::AllReduceOp reduceOp) -> WalkResult {
      if (!reduceOp.getUniform())
        return WalkResult::interrupt();

      reduceOps.emplace_back(reduceOp);
      return WalkResult::advance();
    };

    if (funcOp.walk(callback).wasInterrupted() || reduceOps.empty())
      return rewriter.notifyMatchFailure(
          op, "Non-uniform reductions are not supported yet.");

    for (gpu::AllReduceOp reduceOp : reduceOps)
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();

    return success();
  }
};
} // namespace

void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
  patterns.add<GpuAllReduceRewrite>(patterns.getContext());
}
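
// One plausible way to drive these patterns (the `module` value is
// hypothetical; applyPatternsAndFoldGreedily is declared in
// mlir/Transforms/GreedyPatternRewriteDriver.h):
//
//   RewritePatternSet patterns(module.getContext());
//   populateGpuAllReducePatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));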