MLIR  19.0.0git
SubgroupReduceLowering.cpp
Go to the documentation of this file.
1 //===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implements gradual lowering of `gpu.subgroup_reduce` ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
27 
28 using namespace mlir;
29 
30 namespace {
31 
32 /// Example, assumes `maxShuffleBitwidth` equal to 32:
33 /// ```
34 /// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
35 /// ==>
36 /// %v0 = arith.constant dense<0.0> : vector<3xf16>
37 /// %e0 = vector.extract_strided_slice %x
38 /// {offsets = [0], sizes = [2], strides = [1}: vector<3xf32> to vector<2xf32>
39 /// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
40 /// %v1 = vector.insert_strided_slice %r0, %v0
41 /// {offsets = [0], strides = [1}: vector<2xf32> into vector<3xf32>
42 /// %e1 = vector.extract %x[2] : f16 from vector<2xf16>
43 /// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
44 /// %a = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
45 /// ```
46 struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
47  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
48  PatternBenefit benefit)
49  : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
50  }
51 
52  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
53  PatternRewriter &rewriter) const override {
54  auto vecTy = dyn_cast<VectorType>(op.getType());
55  if (!vecTy || vecTy.getNumElements() < 2)
56  return rewriter.notifyMatchFailure(op, "not a multi-element reduction");
57 
58  assert(vecTy.getRank() == 1 && "Unexpected vector type");
59  assert(!vecTy.isScalable() && "Unexpected vector type");
60 
61  Type elemTy = vecTy.getElementType();
62  unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
63  if (elemBitwidth >= maxShuffleBitwidth)
64  return rewriter.notifyMatchFailure(
65  op, llvm::formatv("element type too large ({0}), cannot break down "
66  "into vectors of bitwidth {1} or less",
67  elemBitwidth, maxShuffleBitwidth));
68 
69  unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
70  assert(elementsPerShuffle >= 1);
71 
72  unsigned numNewReductions =
73  llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
74  assert(numNewReductions >= 1);
75  if (numNewReductions == 1)
76  return rewriter.notifyMatchFailure(op, "nothing to break down");
77 
78  Location loc = op.getLoc();
79  Value res =
80  rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));
81 
82  for (unsigned i = 0; i != numNewReductions; ++i) {
83  int64_t startIdx = i * elementsPerShuffle;
84  int64_t endIdx =
85  std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
86  int64_t numElems = endIdx - startIdx;
87 
88  Value extracted;
89  if (numElems == 1) {
90  extracted =
91  rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
92  } else {
93  extracted = rewriter.create<vector::ExtractStridedSliceOp>(
94  loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
95  /*strides=*/1);
96  }
97 
98  Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
99  loc, extracted, op.getOp(), op.getUniform());
100  if (numElems == 1) {
101  res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
102  continue;
103  }
104 
105  res = rewriter.create<vector::InsertStridedSliceOp>(
106  loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
107  }
108 
109  rewriter.replaceOp(op, res);
110  return success();
111  }
112 
113 private:
114  unsigned maxShuffleBitwidth = 0;
115 };
116 
117 /// Example:
118 /// ```
119 /// %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32>
120 /// ==>
121 /// %e0 = vector.extract %x[0] : f32 from vector<1xf32>
122 /// %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32
123 /// %a = vector.broadcast %r0 : f32 to vector<1xf32>
124 /// ```
125 struct ScalarizeSingleElementReduce final
126  : OpRewritePattern<gpu::SubgroupReduceOp> {
128 
129  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
130  PatternRewriter &rewriter) const override {
131  auto vecTy = dyn_cast<VectorType>(op.getType());
132  if (!vecTy || vecTy.getNumElements() != 1)
133  return rewriter.notifyMatchFailure(op, "not a single-element reduction");
134 
135  assert(vecTy.getRank() == 1 && "Unexpected vector type");
136  assert(!vecTy.isScalable() && "Unexpected vector type");
137  Location loc = op.getLoc();
138  Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
139  Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
140  loc, extracted, op.getOp(), op.getUniform());
141  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
142  return success();
143  }
144 };
145 
146 /// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
147 /// and `unpackFn` to convert to the native shuffle type and to the reduction
148 /// type, respectively. For example, with `input` of type `f16`, `packFn` could
149 /// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
150 /// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
151 /// the subgroup is `subgroupSize` lanes wide and reduces across all of them.
152 static Value createSubgroupShuffleReduction(
153  OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
154  unsigned subgroupSize, function_ref<Value(Value)> packFn,
155  function_ref<Value(Value)> unpackFn) {
156  assert(llvm::isPowerOf2_32(subgroupSize));
157  // Lane value always stays in the original type. We use it to perform arith
158  // reductions.
159  Value laneVal = input;
160  // Parallel reduction using butterfly shuffles.
161  for (unsigned i = 1; i < subgroupSize; i <<= 1) {
162  Value shuffled = builder
163  .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
164  /*width=*/subgroupSize,
165  /*mode=*/gpu::ShuffleMode::XOR)
166  .getShuffleResult();
167  laneVal = vector::makeArithReduction(builder, loc,
169  laneVal, unpackFn(shuffled));
170  assert(laneVal.getType() == input.getType());
171  }
172 
173  return laneVal;
174 }
175 
176 /// Lowers scalar gpu subgroup reductions to a series of shuffles.
177 struct ScalarSubgroupReduceToShuffles final
178  : OpRewritePattern<gpu::SubgroupReduceOp> {
179  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
180  unsigned shuffleBitwidth,
181  PatternBenefit benefit)
182  : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
183  shuffleBitwidth(shuffleBitwidth) {}
184 
185  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
186  PatternRewriter &rewriter) const override {
187  Type valueTy = op.getType();
188  unsigned elemBitwidth =
190  if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
191  return rewriter.notifyMatchFailure(
192  op, "value type is not a compatible scalar");
193 
194  Location loc = op.getLoc();
195  // Since this is already a native shuffle scalar, no packing is necessary.
196  if (elemBitwidth == shuffleBitwidth) {
197  auto identityFn = [](Value v) { return v; };
198  rewriter.replaceOp(op, createSubgroupShuffleReduction(
199  rewriter, loc, op.getValue(), op.getOp(),
200  subgroupSize, identityFn, identityFn));
201  return success();
202  }
203 
204  auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
205  auto equivIntType = rewriter.getIntegerType(elemBitwidth);
206  auto packFn = [loc, &rewriter, equivIntType,
207  shuffleIntType](Value unpackedVal) -> Value {
208  auto asInt =
209  rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
210  return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
211  };
212  auto unpackFn = [loc, &rewriter, equivIntType,
213  valueTy](Value packedVal) -> Value {
214  auto asInt =
215  rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
216  return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
217  };
218 
219  rewriter.replaceOp(op, createSubgroupShuffleReduction(
220  rewriter, loc, op.getValue(), op.getOp(),
221  subgroupSize, packFn, unpackFn));
222  return success();
223  }
224 
225 private:
226  unsigned subgroupSize = 0;
227  unsigned shuffleBitwidth = 0;
228 };
229 
230 /// Lowers vector gpu subgroup reductions to a series of shuffles.
231 struct VectorSubgroupReduceToShuffles final
232  : OpRewritePattern<gpu::SubgroupReduceOp> {
233  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
234  unsigned shuffleBitwidth,
235  PatternBenefit benefit)
236  : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
237  shuffleBitwidth(shuffleBitwidth) {}
238 
239  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
240  PatternRewriter &rewriter) const override {
241  auto vecTy = dyn_cast<VectorType>(op.getType());
242  if (!vecTy)
243  return rewriter.notifyMatchFailure(op, "value type is not a vector");
244 
245  unsigned vecBitwidth =
246  vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
247  if (vecBitwidth > shuffleBitwidth)
248  return rewriter.notifyMatchFailure(
249  op,
250  llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
251  "to shuffles of size {1}",
252  vecBitwidth, shuffleBitwidth));
253 
254  unsigned elementsPerShuffle =
255  shuffleBitwidth / vecTy.getElementTypeBitWidth();
256  if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
257  return rewriter.notifyMatchFailure(
258  op, "shuffle bitwidth is not a multiple of the element bitwidth");
259 
260  Location loc = op.getLoc();
261 
262  // If the reduced type is smaller than the native shuffle size, extend it,
263  // perform the shuffles, and extract at the end.
264  auto extendedVecTy = VectorType::get(
265  static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
266  Value extendedInput = op.getValue();
267  if (vecBitwidth < shuffleBitwidth) {
268  auto zero = rewriter.create<arith::ConstantOp>(
269  loc, rewriter.getZeroAttr(extendedVecTy));
270  extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
271  loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
272  }
273 
274  auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
275  auto shuffleVecType = VectorType::get(1, shuffleIntType);
276 
277  auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
278  auto asIntVec =
279  rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
280  return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
281  };
282  auto unpackFn = [loc, &rewriter, shuffleVecType,
283  extendedVecTy](Value packedVal) -> Value {
284  auto asIntVec =
285  rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
286  return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
287  };
288 
289  Value res =
290  createSubgroupShuffleReduction(rewriter, loc, extendedInput, op.getOp(),
291  subgroupSize, packFn, unpackFn);
292 
293  if (vecBitwidth < shuffleBitwidth) {
294  res = rewriter.create<vector::ExtractStridedSliceOp>(
295  loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
296  /*strides=*/1);
297  }
298 
299  rewriter.replaceOp(op, res);
300  return success();
301  }
302 
303 private:
304  unsigned subgroupSize = 0;
305  unsigned shuffleBitwidth = 0;
306 };
307 } // namespace
308 
310  RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
311  PatternBenefit benefit) {
312  patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
313  maxShuffleBitwidth, benefit);
314  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
315 }
316 
318  RewritePatternSet &patterns, unsigned subgroupSize,
319  unsigned shuffleBitwidth, PatternBenefit benefit) {
320  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
321  patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
322 }
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
Definition: LinalgOps.cpp:2544
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:119
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
vector::CombiningKind convertReductionKind(gpu::AllReduceOperation mode)
Returns the matching vector combining kind.
Definition: Utils.cpp:18
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Include the generated interface declarations.
void populateGpuLowerSubgroupReduceToShufflePattenrs(RewritePatternSet &patterns, unsigned subgroupSize, unsigned shuffleBitwidth=32, PatternBenefit benefit=1)
Collect a set of patterns to lower gpu.subgroup_reduce into gpu.shuffle ops over shuffleBitwidth scal...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateGpuBreakDownSubgrupReducePatterns(RewritePatternSet &patterns, unsigned maxShuffleBitwidth=32, PatternBenefit benefit=1)
Collect a set of patterns to break down subgroup_reduce ops into smaller ones supported by the target...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362