//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements gradual lowering of `gpu.subgroup_reduce` ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>

using namespace mlir;

namespace {

/// Example, assumes `maxShuffleBitwidth` equal to 32:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
/// ==>
/// %v0 = arith.constant dense<0.0> : vector<3xf16>
/// %e0 = vector.extract_strided_slice %x
///   {offsets = [0], sizes = [2], strides = [1]} : vector<3xf16> to vector<2xf16>
/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
/// %v1 = vector.insert_strided_slice %r0, %v0
///   {offsets = [0], strides = [1]} : vector<2xf16> into vector<3xf16>
/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
/// %a = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
/// ```
struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
                          PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
  }

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() < 2)
      return rewriter.notifyMatchFailure(op, "not a multi-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");

    Type elemTy = vecTy.getElementType();
    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
    if (elemBitwidth >= maxShuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("element type too large ({0}), cannot break down "
                            "into vectors of bitwidth {1} or less",
                            elemBitwidth, maxShuffleBitwidth));

    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
    assert(elementsPerShuffle >= 1);

    unsigned numNewReductions =
        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
    assert(numNewReductions >= 1);
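    // E.g., for vector<3xf16> with `maxShuffleBitwidth` 32, this yields
    // elementsPerShuffle = 2 and numNewReductions = ceil(3 / 2) = 2: one
    // vector<2xf16> reduction plus one trailing scalar f16 reduction.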
    if (numNewReductions == 1)
      return rewriter.notifyMatchFailure(op, "nothing to break down");

    Location loc = op.getLoc();
    Value res =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));

    for (unsigned i = 0; i != numNewReductions; ++i) {
      int64_t startIdx = i * elementsPerShuffle;
      int64_t endIdx =
          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
      int64_t numElems = endIdx - startIdx;

      Value extracted;
      if (numElems == 1) {
        extracted =
            rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
      } else {
        extracted = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
            /*strides=*/1);
      }

      Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
          loc, extracted, op.getOp(), op.getUniform());
      if (numElems == 1) {
        res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
        continue;
      }

      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxShuffleBitwidth = 0;
};

/// Example:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32>
/// ==>
/// %e0 = vector.extract %x[0] : f32 from vector<1xf32>
/// %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32
/// %a = vector.broadcast %r0 : f32 to vector<1xf32>
/// ```
struct ScalarizeSingleElementReduce final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() != 1)
      return rewriter.notifyMatchFailure(op, "not a single-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");
    Location loc = op.getLoc();
    Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
    Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
        loc, extracted, op.getOp(), op.getUniform());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
    return success();
  }
};

/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
/// and `unpackFn` to convert to the native shuffle type and to the reduction
/// type, respectively. For example, with `input` of type `f16`, `packFn` could
/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
/// would cast it back to `f16` to perform the arithmetic reduction on. Assumes
/// that the subgroup is `subgroupSize` lanes wide and reduces across all of
/// them.
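/// For instance, for an `i32` input, an `add` reduction, and a subgroup of 8
/// lanes (an illustrative sketch not present in the original comments, where
/// the %cN values are i32 constants), the emitted butterfly sequence is:
/// ```
/// %s0, %v0 = gpu.shuffle xor %x, %c1, %c8 : i32
/// %r0 = arith.addi %x, %s0 : i32
/// %s1, %v1 = gpu.shuffle xor %r0, %c2, %c8 : i32
/// %r1 = arith.addi %r0, %s1 : i32
/// %s2, %v2 = gpu.shuffle xor %r1, %c4, %c8 : i32
/// %r2 = arith.addi %r1, %s2 : i32
/// ```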
static Value createSubgroupShuffleReduction(
    OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
    unsigned subgroupSize, function_ref<Value(Value)> packFn,
    function_ref<Value(Value)> unpackFn) {
  assert(llvm::isPowerOf2_32(subgroupSize));
  // Lane value always stays in the original type. We use it to perform arith
  // reductions.
  Value laneVal = input;
  // Parallel reduction using butterfly shuffles.
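  // Each step XORs the lane id with an offset of 1, 2, 4, ..., subgroupSize/2,
  // so after log2(subgroupSize) exchanges every lane holds the full result.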
  for (unsigned i = 1; i < subgroupSize; i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
                                                 /*width=*/subgroupSize,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = vector::makeArithReduction(builder, loc,
                                         gpu::convertReductionKind(mode),
                                         laneVal, unpackFn(shuffled));
    assert(laneVal.getType() == input.getType());
  }

  return laneVal;
}

/// Lowers scalar gpu subgroup reductions to a series of shuffles.
struct ScalarSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    Type valueTy = op.getType();
    unsigned elemBitwidth =
        getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
    if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "value type is not a compatible scalar");

    Location loc = op.getLoc();
    // Since this is already a native shuffle scalar, no packing is necessary.
    if (elemBitwidth == shuffleBitwidth) {
      auto identityFn = [](Value v) { return v; };
      rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                 rewriter, loc, op.getValue(), op.getOp(),
                                 subgroupSize, identityFn, identityFn));
      return success();
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto equivIntType = rewriter.getIntegerType(elemBitwidth);
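    // E.g., for an f16 value shuffled as i32: pack = arith.bitcast f16 to i16
    // followed by arith.extui to i32; unpack = arith.trunci back to i16
    // followed by arith.bitcast to f16.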
    auto packFn = [loc, &rewriter, equivIntType,
                   shuffleIntType](Value unpackedVal) -> Value {
      auto asInt =
          rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
      return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
    };
    auto unpackFn = [loc, &rewriter, equivIntType,
                     valueTy](Value packedVal) -> Value {
      auto asInt =
          rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
      return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
    };

    rewriter.replaceOp(op, createSubgroupShuffleReduction(
                               rewriter, loc, op.getValue(), op.getOp(),
                               subgroupSize, packFn, unpackFn));
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
};

/// Lowers vector gpu subgroup reductions to a series of shuffles.
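/// Example, assuming `shuffleBitwidth` equal to 32 (an illustrative sketch not
/// present in the original comments): a `vector<2xf16>` lane value is packed
/// for each shuffle as
/// ```
/// %i = vector.bitcast %v : vector<2xf16> to vector<1xi32>
/// %s = vector.extract %i[0] : i32 from vector<1xi32>
/// ```
/// and unpacked with the inverse `vector.broadcast` plus `vector.bitcast`
/// before each `arith` reduction step.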
struct VectorSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy)
      return rewriter.notifyMatchFailure(op, "value type is not a vector");

    unsigned vecBitwidth =
        vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
    if (vecBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op,
          llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
                        "to shuffles of size {1}",
                        vecBitwidth, shuffleBitwidth));

    unsigned elementsPerShuffle =
        shuffleBitwidth / vecTy.getElementTypeBitWidth();
    if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "shuffle bitwidth is not a multiple of the element bitwidth");

    Location loc = op.getLoc();

    // If the reduced type is smaller than the native shuffle size, extend it,
    // perform the shuffles, and extract at the end.
    auto extendedVecTy = VectorType::get(
        static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
    Value extendedInput = op.getValue();
    if (vecBitwidth < shuffleBitwidth) {
      auto zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getZeroAttr(extendedVecTy));
      extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
    }
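    // E.g., a vector<1xf16> input is widened to vector<2xf16> by inserting it
    // into an all-zero vector, so that every shuffle still moves a full
    // 32 bits.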

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto shuffleVecType = VectorType::get(1, shuffleIntType);

    auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
      return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
    };
    auto unpackFn = [loc, &rewriter, shuffleVecType,
                     extendedVecTy](Value packedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
      return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
    };

    Value res =
        createSubgroupShuffleReduction(rewriter, loc, extendedInput, op.getOp(),
                                       subgroupSize, packFn, unpackFn);

    if (vecBitwidth < shuffleBitwidth) {
      res = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
          /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
};
} // namespace

void mlir::populateGpuBreakDownSubgrupReducePatterns(
    RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
    PatternBenefit benefit) {
  patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
                                        maxShuffleBitwidth, benefit);
  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}

void mlir::populateGpuLowerSubgroupReduceToShufflePattenrs(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
}
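
// ----------------------------------------------------------------------------
// Usage sketch (not part of the original file): a hypothetical pass body that
// combines the two populate functions above. The break-down patterns get a
// higher benefit so that oversized reductions are split before the shuffle
// lowering fires; `applyPatternsAndFoldGreedily` is the greedy driver from
// mlir/Transforms/GreedyPatternRewriteDriver.h.
//
//   void MyLowerSubgroupReducePass::runOnOperation() {
//     MLIRContext *ctx = &getContext();
//     RewritePatternSet patterns(ctx);
//     populateGpuBreakDownSubgrupReducePatterns(
//         patterns, /*maxShuffleBitwidth=*/32, /*benefit=*/2);
//     populateGpuLowerSubgroupReduceToShufflePattenrs(
//         patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32,
//         /*benefit=*/1);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }
// ----------------------------------------------------------------------------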