//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements gradual lowering of `gpu.subgroup_reduce` ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/Transforms/Passes.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>

using namespace mlir;

namespace {

/// Example, assumes `maxShuffleBitwidth` equal to 32:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
/// ==>
/// %v0 = arith.constant dense<0.0> : vector<3xf16>
/// %e0 = vector.extract_strided_slice %x
///   {offsets = [0], sizes = [2], strides = [1]} : vector<3xf16> to vector<2xf16>
/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
/// %v1 = vector.insert_strided_slice %r0, %v0
///   {offsets = [0], strides = [1]} : vector<2xf16> into vector<3xf16>
/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
/// %a = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
/// ```
struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
                          PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
  }

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() < 2)
      return rewriter.notifyMatchFailure(op, "not a multi-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");

    Type elemTy = vecTy.getElementType();
    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
    if (elemBitwidth >= maxShuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("element type too large ({0}), cannot break down "
                            "into vectors of bitwidth {1} or less",
                            elemBitwidth, maxShuffleBitwidth));

    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
    assert(elementsPerShuffle >= 1);

    unsigned numNewReductions =
        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
    assert(numNewReductions >= 1);
    if (numNewReductions == 1)
      return rewriter.notifyMatchFailure(op, "nothing to break down");

    Location loc = op.getLoc();
    Value res =
        arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(vecTy));

    for (unsigned i = 0; i != numNewReductions; ++i) {
      int64_t startIdx = i * elementsPerShuffle;
      int64_t endIdx =
          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
      int64_t numElems = endIdx - startIdx;

      Value extracted;
      if (numElems == 1) {
        extracted =
            vector::ExtractOp::create(rewriter, loc, op.getValue(), startIdx);
      } else {
        extracted = vector::ExtractStridedSliceOp::create(
            rewriter, loc, op.getValue(), /*offsets=*/startIdx,
            /*sizes=*/numElems,
            /*strides=*/1);
      }

      Value reduce = gpu::SubgroupReduceOp::create(
          rewriter, loc, extracted, op.getOp(), op.getUniform(),
          op.getClusterSize(), op.getClusterStride());
      if (numElems == 1) {
        res = vector::InsertOp::create(rewriter, loc, reduce, res, startIdx);
        continue;
      }

      res = vector::InsertStridedSliceOp::create(
          rewriter, loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxShuffleBitwidth = 0;
};

/// Example:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32>
/// ==>
/// %e0 = vector.extract %x[0] : f32 from vector<1xf32>
/// %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32
/// %a = vector.broadcast %r0 : f32 to vector<1xf32>
/// ```
struct ScalarizeSingleElementReduce final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() != 1)
      return rewriter.notifyMatchFailure(op, "not a single-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");
    Location loc = op.getLoc();
    Value extracted =
        vector::ExtractOp::create(rewriter, loc, op.getValue(), 0);
    Value reduce = gpu::SubgroupReduceOp::create(
        rewriter, loc, extracted, op.getOp(), op.getUniform(),
        op.getClusterSize(), op.getClusterStride());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
    return success();
  }
};

struct ClusterInfo {
  unsigned clusterStride;
  unsigned clusterSize;
  unsigned subgroupSize;
};

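/// Returns the cluster parameters of `op` (stride, size, and the enclosing
/// subgroup size), or failure if they are inconsistent with `subgroupSize`.
/// As an illustrative example (IR not taken from this file), with a subgroup
/// size of 32, an op such as
/// ```
/// %sum = gpu.subgroup_reduce add %x cluster(size = 4, stride = 2) : (f32) -> f32
/// ```
/// would yield {clusterStride = 2, clusterSize = 4, subgroupSize = 32},
/// whereas a cluster size greater than 32 is reported as an error on the op.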
static FailureOr<ClusterInfo>
getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
  assert(llvm::isPowerOf2_32(subgroupSize));

  std::optional<uint32_t> clusterSize = op.getClusterSize();
  assert(!clusterSize ||
         llvm::isPowerOf2_32(*clusterSize)); // Verifier should've caught this.
  if (clusterSize && *clusterSize > subgroupSize)
    return op.emitOpError()
           << "cluster size " << *clusterSize
           << " is greater than subgroup size " << subgroupSize;
  unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);

  auto clusterStride = op.getClusterStride();
  assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should've caught this.
  if (clusterStride >= subgroupSize)
    return op.emitOpError()
           << "cluster stride " << clusterStride
           << " is not less than subgroup size " << subgroupSize;

  return ClusterInfo{clusterStride, effectiveClusterSize, subgroupSize};
}

/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
/// and `unpackFn` to convert to the native shuffle type and to the reduction
/// type, respectively. For example, with `input` of type `f16`, `packFn` could
/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
/// the subgroup is `subgroupSize` lanes wide and divides it into clusters of
/// `clusterSize` lanes starting at lane 0 with a stride of `clusterStride` for
/// lanes within a cluster, reducing all lanes in each cluster in parallel.
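///
/// As an illustrative sketch (assuming `clusterSize` = 4, `clusterStride` = 1,
/// a 32-lane subgroup, identity pack/unpack functions, and a floating-point
/// add reduction; `%c1`, `%c2`, and `%c32` stand for i32 constants built by
/// the lowering), the emitted butterfly sequence looks roughly like:
/// ```
/// %s0, %v0 = gpu.shuffle xor %x, %c1, %c32 : f32
/// %r0 = arith.addf %x, %s0 : f32
/// %s1, %v1 = gpu.shuffle xor %r0, %c2, %c32 : f32
/// %r1 = arith.addf %r0, %s1 : f32
/// ```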
Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
                                     Value input, gpu::AllReduceOperation mode,
                                     const ClusterInfo &ci,
                                     function_ref<Value(Value)> packFn,
                                     function_ref<Value(Value)> unpackFn) {
  // Lane value always stays in the original type. We use it to perform arith
  // reductions.
  Value laneVal = input;
  // Parallel reduction using butterfly shuffles.
  for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize;
       i <<= 1) {
    Value shuffled = gpu::ShuffleOp::create(builder, loc, packFn(laneVal), i,
                                            /*width=*/ci.subgroupSize,
                                            /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = vector::makeArithReduction(builder, loc,
                                         gpu::convertReductionKind(mode),
                                         laneVal, unpackFn(shuffled));
    assert(laneVal.getType() == input.getType());
  }

  return laneVal;
}

/// Lowers scalar gpu subgroup reductions to a series of shuffles.
struct ScalarSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    Type valueTy = op.getType();
    unsigned elemBitwidth =
        getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
    if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "value type is not a compatible scalar");

    Location loc = op.getLoc();
    // Since this is already a native shuffle scalar, no packing is necessary.
    if (elemBitwidth == shuffleBitwidth) {
      auto identityFn = [](Value v) { return v; };
      rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                 rewriter, loc, op.getValue(), op.getOp(), *ci,
                                 identityFn, identityFn));
      return success();
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto equivIntType = rewriter.getIntegerType(elemBitwidth);
    auto packFn = [loc, &rewriter, equivIntType,
                   shuffleIntType](Value unpackedVal) -> Value {
      auto asInt =
          arith::BitcastOp::create(rewriter, loc, equivIntType, unpackedVal);
      return arith::ExtUIOp::create(rewriter, loc, shuffleIntType, asInt);
    };
    auto unpackFn = [loc, &rewriter, equivIntType,
                     valueTy](Value packedVal) -> Value {
      auto asInt =
          arith::TruncIOp::create(rewriter, loc, equivIntType, packedVal);
      return arith::BitcastOp::create(rewriter, loc, valueTy, asInt);
    };

    rewriter.replaceOp(
        op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
                                           op.getOp(), *ci, packFn, unpackFn));
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};

/// Lowers vector gpu subgroup reductions to a series of shuffles.
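///
/// As an illustrative sketch (IR not taken from this file), assuming a 32-bit
/// shuffle width, one pack/shuffle/unpack step for a `vector<2xf16>` input
/// looks roughly like the following, where `%offset` and `%width` stand for
/// the i32 constants materialized by the lowering:
/// ```
/// %p0 = vector.bitcast %v : vector<2xf16> to vector<1xi32>
/// %p1 = vector.extract %p0[0] : i32 from vector<1xi32>
/// %s, %valid = gpu.shuffle xor %p1, %offset, %width : i32
/// %u0 = vector.broadcast %s : i32 to vector<1xi32>
/// %u1 = vector.bitcast %u0 : vector<1xi32> to vector<2xf16>
/// ```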
struct VectorSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy)
      return rewriter.notifyMatchFailure(op, "value type is not a vector");

    unsigned vecBitwidth =
        vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
    if (vecBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op,
          llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
                        "to shuffles of size {1}",
                        vecBitwidth, shuffleBitwidth));

    unsigned elementsPerShuffle =
        shuffleBitwidth / vecTy.getElementTypeBitWidth();
    if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "shuffle bitwidth is not a multiple of the element bitwidth");

    Location loc = op.getLoc();

    // If the reduced type is smaller than the native shuffle size, extend it,
    // perform the shuffles, and extract at the end.
    auto extendedVecTy = VectorType::get(
        static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
    Value extendedInput = op.getValue();
    if (vecBitwidth < shuffleBitwidth) {
      auto zero = arith::ConstantOp::create(
          rewriter, loc, rewriter.getZeroAttr(extendedVecTy));
      extendedInput = vector::InsertStridedSliceOp::create(
          rewriter, loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto shuffleVecType = VectorType::get(1, shuffleIntType);

    auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
      auto asIntVec =
          vector::BitCastOp::create(rewriter, loc, shuffleVecType, unpackedVal);
      return vector::ExtractOp::create(rewriter, loc, asIntVec, 0);
    };
    auto unpackFn = [loc, &rewriter, shuffleVecType,
                     extendedVecTy](Value packedVal) -> Value {
      auto asIntVec =
          vector::BroadcastOp::create(rewriter, loc, shuffleVecType, packedVal);
      return vector::BitCastOp::create(rewriter, loc, extendedVecTy, asIntVec);
    };

    Value res = createSubgroupShuffleReduction(
        rewriter, loc, extendedInput, op.getOp(), *ci, packFn, unpackFn);

    if (vecBitwidth < shuffleBitwidth) {
      res = vector::ExtractStridedSliceOp::create(
          rewriter, loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
          /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};

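/// Emits a reduction of `input` across each cluster of contiguous lanes using
/// AMD DPP (data parallel primitives) permutations, falling back to
/// `rocdl.permlanex16` / `rocdl.readlane` where a step cannot be expressed as
/// a DPP permutation on the target `chipset`. Each step of the ladder below
/// (quad_perm swaps, row_half_mirror, row_mirror, and row broadcasts) doubles
/// the number of lanes folded into the running value, mirroring the
/// butterfly-shuffle lowering above.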
static FailureOr<Value>
createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
                           Value input, gpu::AllReduceOperation mode,
                           const ClusterInfo &ci, amdgpu::Chipset chipset) {
  Location loc = op.getLoc();
  Value dpp;
  Value res = input;
  constexpr int allRows = 0xf;
  constexpr int allBanks = 0xf;
  const bool boundCtrl = true;
  if (ci.clusterSize >= 2) {
    // Perform reduction between all lanes N <-> N+1.
    dpp = amdgpu::DPPOp::create(
        rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
        rewriter.getI32ArrayAttr({1, 0, 3, 2}), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }

  if (ci.clusterSize >= 4) {
    // Perform reduction between all lanes N <-> N+2.
    dpp = amdgpu::DPPOp::create(
        rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
        rewriter.getI32ArrayAttr({2, 3, 0, 1}), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 8) {
    // Perform reduction between all lanes N <-> 7-N,
    // e.g. lane[0] <-> lane[7], lane[1] <-> lane[6]..., lane[3] <-> lane[4].
    dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res,
                                amdgpu::DPPPerm::row_half_mirror,
                                rewriter.getUnitAttr(), allRows, allBanks,
                                boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 16) {
    // Perform reduction between all lanes N <-> 15-N,
    // e.g. lane[0] <-> lane[15], lane[1] <-> lane[14]..., lane[7] <-> lane[8].
    dpp = amdgpu::DPPOp::create(
        rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::row_mirror,
        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 32) {
    if (chipset.majorVersion <= 9) {
      // Broadcast last value from each row to next row.
      // Use row mask to avoid polluting rows 1 and 3.
      dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res,
                                  amdgpu::DPPPerm::row_bcast_15,
                                  rewriter.getUnitAttr(), 0xa, allBanks,
                                  /*bound_ctrl*/ false);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
    } else if (chipset.majorVersion <= 12) {
      // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2).
      Value uint32Max = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(-1));
      dpp = ROCDL::PermlaneX16Op::create(rewriter, loc, res.getType(), res, res,
                                         uint32Max, uint32Max,
                                         /*fi=*/true,
                                         /*bound_ctrl=*/false);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
    } else {
      return rewriter.notifyMatchFailure(
          op, "Subgroup reduce lowering to DPP not currently supported for "
              "this device.");
    }
    if (ci.subgroupSize == 32) {
      Value lane31 = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31));
      res =
          ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane31);
    }
  }
  if (ci.clusterSize >= 64) {
    if (chipset.majorVersion <= 9) {
      // Broadcast 31st lane value to rows 2 and 3.
      dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res,
                                  amdgpu::DPPPerm::row_bcast_31,
                                  rewriter.getUnitAttr(), 0xf, allBanks,
                                  /*bound_ctrl*/ true);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), dpp, res);
      // Obtain the reduction from the last rows; the previous rows are
      // polluted.
      Value lane63 = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
      res =
          ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane63);

    } else if (chipset.majorVersion <= 12) {
      // Assume the reduction across each 32-lane half has already been done.
      // Perform the final reduction by combining the values read from lane 31
      // and lane 63 (the last lane of each half).
      Value lane31 = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31));
      Value lane63 = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
      lane31 =
          ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane31);
      lane63 =
          ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane63);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), lane31, lane63);
    } else {
      return rewriter.notifyMatchFailure(
          op, "Subgroup reduce lowering to DPP not currently supported for "
              "this device.");
    }
  }
  assert(res.getType() == input.getType());
  return res;
}

/// Lowers scalar `gpu.subgroup_reduce` ops into a sequence of `amdgpu.dpp`
/// ops. Assumes that the subgroup has `subgroupSize` lanes. Applicable only to
/// AMD GPUs.
struct ScalarSubgroupReduceToDPP final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToDPP(MLIRContext *ctx, unsigned subgroupSize,
                            bool matchClustered, amdgpu::Chipset chipset,
                            PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        matchClustered(matchClustered), chipset(chipset) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }
    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    if (ci->clusterStride != 1)
      return rewriter.notifyMatchFailure(
          op, "Subgroup reductions using DPP are currently only available for "
              "clusters of contiguous lanes.");

    Type valueTy = op.getType();
    if (!valueTy.isIntOrFloat())
      return rewriter.notifyMatchFailure(
          op, "Value type is not a compatible scalar.");

    FailureOr<Value> dpp = createSubgroupDPPReduction(
        rewriter, op, op.getValue(), op.getOp(), *ci, chipset);
    if (failed(dpp))
      return failure();

    rewriter.replaceOp(op, dpp.value());
    return success();
  }

private:
  unsigned subgroupSize = 0;
  bool matchClustered = false;
  amdgpu::Chipset chipset;
};
} // namespace

void mlir::populateGpuBreakDownSubgroupReducePatterns(
    RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
    PatternBenefit benefit) {
  patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
                                        maxShuffleBitwidth, benefit);
  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}

void mlir::populateGpuLowerSubgroupReduceToDPPPatterns(
    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
    PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
                                          /*matchClustered=*/false, chipset,
                                          benefit);
}

void mlir::populateGpuLowerClusteredSubgroupReduceToDPPPatterns(
    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
    PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
                                          /*matchClustered=*/true, chipset,
                                          benefit);
}

void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/false, benefit);
}

void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/true, benefit);
}
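
// Illustrative usage sketch (not part of this file): registering these
// patterns inside a hypothetical pass. The pass name, the chosen subgroup
// size, and the shuffle bitwidth are assumptions for the example; the greedy
// driver entry point is `applyPatternsGreedily` in recent MLIR (formerly
// `applyPatternsAndFoldGreedily`).
//
//   void MyLowerSubgroupReducePass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
//         patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
//     populateGpuLowerSubgroupReduceToShufflePatterns(
//         patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
//     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//       signalPassFailure();
//   }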