//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements gradual lowering of `gpu.subgroup_reduce` ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>

using namespace mlir;

namespace {

/// Example, assumes `maxShuffleBitwidth` equal to 32:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
/// ==>
/// %v0 = arith.constant dense<0.0> : vector<3xf16>
/// %e0 = vector.extract_strided_slice %x
///   {offsets = [0], sizes = [2], strides = [1]} : vector<3xf16> to vector<2xf16>
/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
/// %v1 = vector.insert_strided_slice %r0, %v0
///   {offsets = [0], strides = [1]} : vector<2xf16> into vector<3xf16>
/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
/// %a = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
/// ```
struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
                          PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
  }

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() < 2)
      return rewriter.notifyMatchFailure(op, "not a multi-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");

    Type elemTy = vecTy.getElementType();
    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
    if (elemBitwidth >= maxShuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("element type too large ({0}), cannot break down "
                            "into vectors of bitwidth {1} or less",
                            elemBitwidth, maxShuffleBitwidth));

    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
    assert(elementsPerShuffle >= 1);

    unsigned numNewReductions =
        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
    assert(numNewReductions >= 1);
    if (numNewReductions == 1)
      return rewriter.notifyMatchFailure(op, "nothing to break down");

    Location loc = op.getLoc();
    Value res =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));

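    // Illustrative example (cf. the documentation comment above): for a
    // `vector<3xf16>` input and `maxShuffleBitwidth` of 32, elementsPerShuffle
    // is 32 / 16 = 2 and numNewReductions is ceil(3 / 2) = 2, so the loop
    // below emits one `vector<2xf16>` reduction followed by one scalar `f16`
    // reduction.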
    for (unsigned i = 0; i != numNewReductions; ++i) {
      int64_t startIdx = i * elementsPerShuffle;
      int64_t endIdx =
          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
      int64_t numElems = endIdx - startIdx;

      Value extracted;
      if (numElems == 1) {
        extracted =
            rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
      } else {
        extracted = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
            /*strides=*/1);
      }

      Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
          loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
          op.getClusterStride());
      if (numElems == 1) {
        res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
        continue;
      }

      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxShuffleBitwidth = 0;
};

/// Example:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32>
/// ==>
/// %e0 = vector.extract %x[0] : f32 from vector<1xf32>
/// %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32
/// %a = vector.broadcast %r0 : f32 to vector<1xf32>
/// ```
struct ScalarizeSingleElementReduce final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() != 1)
      return rewriter.notifyMatchFailure(op, "not a single-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");
    Location loc = op.getLoc();
    Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
    Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
        loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
        op.getClusterStride());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
    return success();
  }
};

struct ClusterInfo {
  unsigned clusterStride;
  unsigned clusterSize;
  unsigned subgroupSize;
};

static FailureOr<ClusterInfo>
getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
  assert(llvm::isPowerOf2_32(subgroupSize));

  std::optional<uint32_t> clusterSize = op.getClusterSize();
  assert(!clusterSize ||
         llvm::isPowerOf2_32(*clusterSize)); // Verifier should've caught this.
  if (clusterSize && *clusterSize > subgroupSize)
    return op.emitOpError()
           << "cluster size " << *clusterSize
           << " is greater than subgroup size " << subgroupSize;
  unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);

  auto clusterStride = op.getClusterStride();
  assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should've caught this.
  if (clusterStride >= subgroupSize)
    return op.emitOpError()
           << "cluster stride " << clusterStride
           << " is not less than subgroup size " << subgroupSize;

  return ClusterInfo{clusterStride, effectiveClusterSize, subgroupSize};
}

/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
/// and `unpackFn` to convert to the native shuffle type and to the reduction
/// type, respectively. For example, with `input` of type `f16`, `packFn` could
/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
/// the subgroup is `subgroupSize` lanes wide and divides it into clusters of
/// `clusterSize` lanes starting at lane 0 with a stride of `clusterStride` for
/// lanes within a cluster, reducing all lanes in each cluster in parallel.
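/// For example (illustrative), with `subgroupSize` = 32, `clusterSize` = 4,
/// and `clusterStride` = 8, lanes {0, 8, 16, 24} form one cluster, lanes
/// {1, 9, 17, 25} the next, and so on; the emitted `gpu.shuffle xor` offsets
/// are then 8 and 16, so after log2(4) = 2 butterfly steps every lane holds
/// its cluster's reduction.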
Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
                                     Value input, gpu::AllReduceOperation mode,
                                     const ClusterInfo &ci,
                                     function_ref<Value(Value)> packFn,
                                     function_ref<Value(Value)> unpackFn) {
  // Lane value always stays in the original type. We use it to perform arith
  // reductions.
  Value laneVal = input;
  // Parallel reduction using butterfly shuffles.
  for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize;
       i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
                                                 /*width=*/ci.subgroupSize,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = vector::makeArithReduction(builder, loc,
                                         gpu::convertReductionKind(mode),
                                         laneVal, unpackFn(shuffled));
    assert(laneVal.getType() == input.getType());
  }

  return laneVal;
}

/// Lowers scalar gpu subgroup reductions to a series of shuffles.
struct ScalarSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    Type valueTy = op.getType();
    unsigned elemBitwidth =
        getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
    if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "value type is not a compatible scalar");

    Location loc = op.getLoc();
    // Since this is already a native shuffle scalar, no packing is necessary.
    if (elemBitwidth == shuffleBitwidth) {
      auto identityFn = [](Value v) { return v; };
      rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                 rewriter, loc, op.getValue(), op.getOp(), *ci,
                                 identityFn, identityFn));
      return success();
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto equivIntType = rewriter.getIntegerType(elemBitwidth);
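    // Illustrative example: for an `f16` value with a 32-bit shuffle width,
    // `packFn` below bitcasts f16 -> i16 and zero-extends i16 -> i32 before
    // the shuffle, while `unpackFn` truncates the shuffled i32 back to i16
    // and bitcasts it to f16 for the arithmetic reduction.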
    auto packFn = [loc, &rewriter, equivIntType,
                   shuffleIntType](Value unpackedVal) -> Value {
      auto asInt =
          rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
      return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
    };
    auto unpackFn = [loc, &rewriter, equivIntType,
                     valueTy](Value packedVal) -> Value {
      auto asInt =
          rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
      return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
    };

    rewriter.replaceOp(
        op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
                                           op.getOp(), *ci, packFn, unpackFn));
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};

/// Lowers vector gpu subgroup reductions to a series of shuffles.
struct VectorSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy)
      return rewriter.notifyMatchFailure(op, "value type is not a vector");

    unsigned vecBitwidth =
        vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
    if (vecBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op,
          llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
                        "to shuffles of size {1}",
                        vecBitwidth, shuffleBitwidth));

    unsigned elementsPerShuffle =
        shuffleBitwidth / vecTy.getElementTypeBitWidth();
    if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "shuffle bitwidth is not a multiple of the element bitwidth");

    Location loc = op.getLoc();

    // If the reduced type is smaller than the native shuffle size, extend it,
    // perform the shuffles, and extract at the end.
    auto extendedVecTy = VectorType::get(
        static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
    Value extendedInput = op.getValue();
    if (vecBitwidth < shuffleBitwidth) {
      auto zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getZeroAttr(extendedVecTy));
      extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
    }
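    // Illustrative example: a `vector<1xf16>` input is widened here to
    // `vector<2xf16>` by inserting it into a zero constant; the padding
    // element is reduced alongside the real data and then dropped again by the
    // `vector.extract_strided_slice` below.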

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto shuffleVecType = VectorType::get(1, shuffleIntType);

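    // Illustrative example: for a (possibly widened) `vector<2xf16>` payload,
    // `packFn` below bitcasts it to `vector<1xi32>` and extracts the single
    // i32 to feed the shuffle; `unpackFn` broadcasts the shuffled i32 back
    // into a `vector<1xi32>` and bitcasts it to `vector<2xf16>`.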
    auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
      return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
    };
    auto unpackFn = [loc, &rewriter, shuffleVecType,
                     extendedVecTy](Value packedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
      return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
    };

    Value res = createSubgroupShuffleReduction(
        rewriter, loc, extendedInput, op.getOp(), *ci, packFn, unpackFn);

    if (vecBitwidth < shuffleBitwidth) {
      res = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
          /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};
} // namespace

void mlir::populateGpuBreakDownSubgroupReducePatterns(
    RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
    PatternBenefit benefit) {
  patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
                                        maxShuffleBitwidth, benefit);
  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}

void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/false, benefit);
}

void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/true, benefit);
}
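
// Usage sketch (illustrative): a lowering pass would typically register these
// entry points with the greedy pattern driver, e.g.:
//
//   RewritePatternSet patterns(ctx);
//   populateGpuBreakDownSubgroupReducePatterns(patterns,
//                                              /*maxShuffleBitwidth=*/32);
//   populateGpuLowerSubgroupReduceToShufflePatterns(patterns,
//                                                   /*subgroupSize=*/32);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));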