//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements gradual lowering of `gpu.subgroup_reduce` ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>

using namespace mlir;

namespace {

/// Example, assuming `maxShuffleBitwidth` equal to 32:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
/// ==>
/// %v0 = arith.constant dense<0.0> : vector<3xf16>
/// %e0 = vector.extract_strided_slice %x
///   {offsets = [0], sizes = [2], strides = [1]} : vector<3xf16> to vector<2xf16>
/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
/// %v1 = vector.insert_strided_slice %r0, %v0
///   {offsets = [0], strides = [1]} : vector<2xf16> into vector<3xf16>
/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
/// %a = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
/// ```
struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
                          PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
  }

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() < 2)
      return rewriter.notifyMatchFailure(op, "not a multi-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");

    Type elemTy = vecTy.getElementType();
    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
    if (elemBitwidth >= maxShuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("element type too large ({0}), cannot break down "
                            "into vectors of bitwidth {1} or less",
                            elemBitwidth, maxShuffleBitwidth));

    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
    assert(elementsPerShuffle >= 1);

    unsigned numNewReductions =
        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
    assert(numNewReductions >= 1);
    if (numNewReductions == 1)
      return rewriter.notifyMatchFailure(op, "nothing to break down");

    Location loc = op.getLoc();
    Value res =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));

    for (unsigned i = 0; i != numNewReductions; ++i) {
      int64_t startIdx = i * elementsPerShuffle;
      int64_t endIdx =
          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
      int64_t numElems = endIdx - startIdx;

      Value extracted;
      if (numElems == 1) {
        extracted =
            rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
      } else {
        extracted = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
            /*strides=*/1);
      }

      Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
          loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
          op.getClusterStride());
      if (numElems == 1) {
        res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
        continue;
      }

      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxShuffleBitwidth = 0;
};

/// Example:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32>
/// ==>
/// %e0 = vector.extract %x[0] : f32 from vector<1xf32>
/// %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32
/// %a = vector.broadcast %r0 : f32 to vector<1xf32>
/// ```
struct ScalarizeSingleElementReduce final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() != 1)
      return rewriter.notifyMatchFailure(op, "not a single-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");
    Location loc = op.getLoc();
    Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
    Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
        loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
        op.getClusterStride());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
    return success();
  }
};

struct ClusterInfo {
  unsigned clusterStride;
  unsigned clusterSize;
  unsigned subgroupSize;
};

static FailureOr<ClusterInfo>
getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
  assert(llvm::isPowerOf2_32(subgroupSize));

  std::optional<uint32_t> clusterSize = op.getClusterSize();
  assert(!clusterSize ||
         llvm::isPowerOf2_32(*clusterSize)); // Verifier should've caught this.
  if (clusterSize && *clusterSize > subgroupSize)
    return op.emitOpError()
           << "cluster size " << *clusterSize
           << " is greater than subgroup size " << subgroupSize;
  unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);

  auto clusterStride = op.getClusterStride();
  assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should've caught this.
  if (clusterStride >= subgroupSize)
    return op.emitOpError()
           << "cluster stride " << clusterStride
           << " is not less than subgroup size " << subgroupSize;

  return ClusterInfo{clusterStride, effectiveClusterSize, subgroupSize};
}
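
// Illustrative note (editor's addition, not part of the upstream source): for
// a 32-lane subgroup, an op written roughly as
//   gpu.subgroup_reduce add %x cluster(size = 4, stride = 2) : (f32) -> f32
// yields ClusterInfo{/*clusterStride=*/2, /*clusterSize=*/4,
// /*subgroupSize=*/32}. Each cluster then spans 4 lanes spaced 2 apart, e.g.
// lanes {0, 2, 4, 6}, {1, 3, 5, 7}, {8, 10, 12, 14}, and so on, and every
// cluster is reduced independently of the others.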

/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
/// and `unpackFn` to convert to the native shuffle type and to the reduction
/// type, respectively. For example, with `input` of type `f16`, `packFn` could
/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
/// the subgroup is `subgroupSize` lanes wide and divides it into clusters of
/// `clusterSize` lanes starting at lane 0 with a stride of `clusterStride` for
/// lanes within a cluster, reducing all lanes in each cluster in parallel.
Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
                                     Value input, gpu::AllReduceOperation mode,
                                     const ClusterInfo &ci,
                                     function_ref<Value(Value)> packFn,
                                     function_ref<Value(Value)> unpackFn) {
  // Lane value always stays in the original type. We use it to perform arith
  // reductions.
  Value laneVal = input;
  // Parallel reduction using butterfly shuffles.
  for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize;
       i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
                                                 /*width=*/ci.subgroupSize,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = vector::makeArithReduction(builder, loc,
                                         gpu::convertReductionKind(mode),
                                         laneVal, unpackFn(shuffled));
    assert(laneVal.getType() == input.getType());
  }

  return laneVal;
}
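
// Illustrative note (editor's addition, not part of the upstream source): for
// ClusterInfo{/*clusterStride=*/1, /*clusterSize=*/8, /*subgroupSize=*/32},
// the loop above emits three XOR shuffles with offsets 1, 2, and 4, so after
// log2(clusterSize) butterfly steps every lane in a cluster holds the
// cluster-wide reduction. With a stride of 2 instead, the offsets become
// 2, 4, and 8.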

/// Lowers scalar gpu subgroup reductions to a series of shuffles.
struct ScalarSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    Type valueTy = op.getType();
    unsigned elemBitwidth =
        getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
    if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "value type is not a compatible scalar");

    Location loc = op.getLoc();
    // Since this is already a native shuffle scalar, no packing is necessary.
    if (elemBitwidth == shuffleBitwidth) {
      auto identityFn = [](Value v) { return v; };
      rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                 rewriter, loc, op.getValue(), op.getOp(), *ci,
                                 identityFn, identityFn));
      return success();
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto equivIntType = rewriter.getIntegerType(elemBitwidth);
    auto packFn = [loc, &rewriter, equivIntType,
                   shuffleIntType](Value unpackedVal) -> Value {
      auto asInt =
          rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
      return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
    };
    auto unpackFn = [loc, &rewriter, equivIntType,
                     valueTy](Value packedVal) -> Value {
      auto asInt =
          rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
      return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
    };

    rewriter.replaceOp(
        op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
                                           op.getOp(), *ci, packFn, unpackFn));
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};
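
// Illustrative sketch (editor's addition, not part of the upstream source),
// assuming a 32-wide subgroup and a 32-bit shuffle width: for
//   %r = gpu.subgroup_reduce add %x : (f16) -> f16
// each step of the emitted shuffle loop looks roughly like
//   %packed           = arith.bitcast %lane : f16 to i16
//   %ext              = arith.extui %packed : i16 to i32
//   %shuffled, %valid = gpu.shuffle xor %ext, %offset, %width : i32
//   %trunc            = arith.trunci %shuffled : i32 to i16
//   %unpacked         = arith.bitcast %trunc : i16 to f16
//   %next             = arith.addf %lane, %unpacked : f16
// with %offset taking the values 1, 2, 4, 8, and 16.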

/// Lowers vector gpu subgroup reductions to a series of shuffles.
struct VectorSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy)
      return rewriter.notifyMatchFailure(op, "value type is not a vector");

    unsigned vecBitwidth =
        vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
    if (vecBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op,
          llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
                        "to shuffles of size {1}",
                        vecBitwidth, shuffleBitwidth));

    unsigned elementsPerShuffle =
        shuffleBitwidth / vecTy.getElementTypeBitWidth();
    if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "shuffle bitwidth is not a multiple of the element bitwidth");

    Location loc = op.getLoc();

    // If the reduced type is smaller than the native shuffle size, extend it,
    // perform the shuffles, and extract at the end.
    auto extendedVecTy = VectorType::get(
        static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
    Value extendedInput = op.getValue();
    if (vecBitwidth < shuffleBitwidth) {
      auto zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getZeroAttr(extendedVecTy));
      extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto shuffleVecType = VectorType::get(1, shuffleIntType);

    auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
      return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
    };
    auto unpackFn = [loc, &rewriter, shuffleVecType,
                     extendedVecTy](Value packedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
      return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
    };

    Value res = createSubgroupShuffleReduction(
        rewriter, loc, extendedInput, op.getOp(), *ci, packFn, unpackFn);

    if (vecBitwidth < shuffleBitwidth) {
      res = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
          /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};
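
// Illustrative sketch (editor's addition, not part of the upstream source),
// assuming a 32-bit shuffle width: for
//   %r = gpu.subgroup_reduce add %x : (vector<2xf16>) -> vector<2xf16>
// the lane value stays a vector<2xf16>; each shuffle step packs it with
// `vector.bitcast ... : vector<2xf16> to vector<1xi32>` plus `vector.extract`,
// shuffles the resulting i32, then rebuilds the vector with `vector.broadcast`
// and `vector.bitcast` before the elementwise `arith.addf`. A vector<1xf16>
// input is first zero-extended to vector<2xf16> and narrowed back to
// vector<1xf16> at the end.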

static FailureOr<Value>
createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
                           Value input, gpu::AllReduceOperation mode,
                           const ClusterInfo &ci, amdgpu::Chipset chipset) {
  Location loc = op.getLoc();
  Value dpp;
  Value res = input;
  constexpr int allRows = 0xf;
  constexpr int allBanks = 0xf;
  const bool boundCtrl = true;
  if (ci.clusterSize >= 2) {
    // Perform reduction between all lanes N <-> N+1.
    dpp = rewriter.create<amdgpu::DPPOp>(
        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
        rewriter.getI32ArrayAttr({1, 0, 3, 2}), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }

  if (ci.clusterSize >= 4) {
    // Perform reduction between all lanes N <-> N+2.
    dpp = rewriter.create<amdgpu::DPPOp>(
        loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
        rewriter.getI32ArrayAttr({2, 3, 0, 1}), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 8) {
    // Perform reduction between all lanes N <-> 7-N,
    // e.g. lane[0] <-> lane[7], lane[1] <-> lane[6]..., lane[3] <-> lane[4].
    dpp = rewriter.create<amdgpu::DPPOp>(
        loc, res.getType(), res, res, amdgpu::DPPPerm::row_half_mirror,
        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 16) {
    // Perform reduction between all lanes N <-> 15-N,
    // e.g. lane[0] <-> lane[15], lane[1] <-> lane[14]..., lane[7] <-> lane[8].
    dpp = rewriter.create<amdgpu::DPPOp>(
        loc, res.getType(), res, res, amdgpu::DPPPerm::row_mirror,
        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 32) {
    if (chipset.majorVersion <= 9) {
      // Broadcast last value from each row to next row.
      // Use row mask to avoid polluting rows 1 and 3.
      dpp = rewriter.create<amdgpu::DPPOp>(
          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_15,
          rewriter.getUnitAttr(), 0xa, allBanks,
          /*bound_ctrl*/ false);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
    } else if (chipset.majorVersion <= 12) {
      // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2).
      Value uint32Max = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(-1));
      dpp = rewriter.create<ROCDL::PermlaneX16Op>(loc, res.getType(), res, res,
                                                  uint32Max, uint32Max,
                                                  /*fi=*/true,
                                                  /*bound_ctrl=*/false);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
      if (ci.subgroupSize == 32) {
        Value lane0 = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
        res =
            rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
      }
    } else {
      return rewriter.notifyMatchFailure(
          op, "Subgroup reduce lowering to DPP not currently supported for "
              "this device.");
    }
  }
  if (ci.clusterSize >= 64) {
    if (chipset.majorVersion <= 9) {
      // Broadcast 31st lane value to rows 2 and 3.
      // Use row mask to avoid polluting rows 0 and 1.
      dpp = rewriter.create<amdgpu::DPPOp>(
          loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_31,
          rewriter.getUnitAttr(), 0xc, allBanks,
          /*bound_ctrl*/ false);

    } else if (chipset.majorVersion <= 12) {
      // Assume reduction across 32 lanes has been done.
      // Perform final reduction manually by summing values in lane 0 and
      // lane 32.
      Value lane0 = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
      Value lane32 = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(32));
      dpp = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane32);
      res = rewriter.create<ROCDL::ReadlaneOp>(loc, res.getType(), res, lane0);
    } else {
      return rewriter.notifyMatchFailure(
          op, "Subgroup reduce lowering to DPP not currently supported for "
              "this device.");
    }
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  assert(res.getType() == input.getType());
  return res;
}
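
// Illustrative note (editor's addition, not part of the upstream source): on a
// gfx9-series chipset with a 64-wide wavefront and clusterSize == 64, the
// stages above run in sequence: quad_perm([1,0,3,2]), quad_perm([2,3,0,1]),
// row_half_mirror, row_mirror, row_bcast_15, and row_bcast_31, each followed
// by one arith reduction, i.e. six DPP ops in total. On gfx10-12 the two row
// broadcasts are replaced by permlanex16 plus readlane for the final
// cross-half step.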

/// Lowers scalar `gpu.subgroup_reduce` ops into `amdgpu.dpp` ops. Assumes that
/// the subgroup has `subgroupSize` lanes. Applicable only to AMD GPUs.
struct ScalarSubgroupReduceToDPP final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToDPP(MLIRContext *ctx, unsigned subgroupSize,
                            bool matchClustered, amdgpu::Chipset chipset,
                            PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        matchClustered(matchClustered), chipset(chipset) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }
    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    if (ci->clusterStride != 1)
      return rewriter.notifyMatchFailure(
          op, "Subgroup reductions using DPP are currently only available for "
              "clusters of contiguous lanes.");

    Type valueTy = op.getType();
    if (!valueTy.isIntOrFloat())
      return rewriter.notifyMatchFailure(
          op, "Value type is not a compatible scalar.");

    FailureOr<Value> dpp = createSubgroupDPPReduction(
        rewriter, op, op.getValue(), op.getOp(), *ci, chipset);
    if (failed(dpp))
      return failure();

    rewriter.replaceOp(op, dpp.value());
    return success();
  }

private:
  unsigned subgroupSize = 0;
  bool matchClustered = false;
  amdgpu::Chipset chipset;
};
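
// Illustrative note (editor's addition, not part of the upstream source): with
// matchClustered == true, a clustered op such as
//   %r = gpu.subgroup_reduce add %x cluster(size = 2) : (f32) -> f32
// on a gfx9-series chipset lowers to a single `amdgpu.dpp ... quad_perm`
// swapping neighboring lanes, followed by one `arith.addf`, since the cluster
// only spans lane pairs {2k, 2k+1}. Larger cluster sizes add the later DPP
// stages shown in createSubgroupDPPReduction.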
} // namespace

void mlir::populateGpuBreakDownSubgroupReducePatterns(
    RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
    PatternBenefit benefit) {
  patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
                                        maxShuffleBitwidth, benefit);
  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}

void mlir::populateGpuLowerSubgroupReduceToDPPPatterns(
    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
    PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
                                          /*matchClustered=*/false, chipset,
                                          benefit);
}

void mlir::populateGpuLowerClusteredSubgroupReduceToDPPPatterns(
    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
    PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
                                          /*matchClustered=*/true, chipset,
                                          benefit);
}

void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/false, benefit);
}

void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/true, benefit);
}
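
// Usage sketch (editor's addition, not part of the upstream source): a pass
// wanting this lowering might populate and apply the patterns roughly as
// follows, here assuming a 32-wide subgroup and the default 32-bit shuffle
// width:
//
//   void runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     populateGpuBreakDownSubgroupReducePatterns(patterns,
//                                                /*maxShuffleBitwidth=*/32);
//     populateGpuLowerSubgroupReduceToShufflePatterns(patterns,
//                                                     /*subgroupSize=*/32);
//     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//       signalPassFailure();
//   }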