//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements gradual lowering of `gpu.subgroup_reduce` ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Transforms/Utils.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>

using namespace mlir;

namespace {

/// Example, assumes `maxShuffleBitwidth` equal to 32:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
/// ==>
/// %v0 = arith.constant dense<0.0> : vector<3xf16>
/// %e0 = vector.extract_strided_slice %x
///   {offsets = [0], sizes = [2], strides = [1]} : vector<3xf16> to vector<2xf16>
/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
/// %v1 = vector.insert_strided_slice %r0, %v0
///   {offsets = [0], strides = [1]} : vector<2xf16> into vector<3xf16>
/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
/// %a = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
/// ```
struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
                          PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
  }

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() < 2)
      return rewriter.notifyMatchFailure(op, "not a multi-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");

    Type elemTy = vecTy.getElementType();
    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
    if (elemBitwidth >= maxShuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("element type too large ({0}), cannot break down "
                            "into vectors of bitwidth {1} or less",
                            elemBitwidth, maxShuffleBitwidth));

    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
    assert(elementsPerShuffle >= 1);

    unsigned numNewReductions =
        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
    assert(numNewReductions >= 1);
    if (numNewReductions == 1)
      return rewriter.notifyMatchFailure(op, "nothing to break down");

    Location loc = op.getLoc();
    Value res =
        arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(vecTy));

    for (unsigned i = 0; i != numNewReductions; ++i) {
      int64_t startIdx = i * elementsPerShuffle;
      int64_t endIdx =
          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
      int64_t numElems = endIdx - startIdx;

      Value extracted;
      if (numElems == 1) {
        extracted =
            vector::ExtractOp::create(rewriter, loc, op.getValue(), startIdx);
      } else {
        extracted = vector::ExtractStridedSliceOp::create(
            rewriter, loc, op.getValue(), /*offsets=*/startIdx,
            /*sizes=*/numElems,
            /*strides=*/1);
      }

      Value reduce = gpu::SubgroupReduceOp::create(
          rewriter, loc, extracted, op.getOp(), op.getUniform(),
          op.getClusterSize(), op.getClusterStride());
      if (numElems == 1) {
        res = vector::InsertOp::create(rewriter, loc, reduce, res, startIdx);
        continue;
      }

      res = vector::InsertStridedSliceOp::create(
          rewriter, loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxShuffleBitwidth = 0;
};

/// Example:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32>
/// ==>
/// %e0 = vector.extract %x[0] : f32 from vector<1xf32>
/// %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32
/// %a = vector.broadcast %r0 : f32 to vector<1xf32>
/// ```
struct ScalarizeSingleElementReduce final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() != 1)
      return rewriter.notifyMatchFailure(op, "not a single-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");
    Location loc = op.getLoc();
    Value extracted =
        vector::ExtractOp::create(rewriter, loc, op.getValue(), 0);
    Value reduce = gpu::SubgroupReduceOp::create(
        rewriter, loc, extracted, op.getOp(), op.getUniform(),
        op.getClusterSize(), op.getClusterStride());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
    return success();
  }
};

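/// Describes the cluster of lanes that a `gpu.subgroup_reduce` op reduces
/// over: each cluster is `clusterSize` lanes wide, with consecutive cluster
/// members spaced `clusterStride` lanes apart. For illustration (an assumption
/// about the op's cluster semantics, not spelled out in this file),
/// `clusterSize = 4` with `clusterStride = 2` groups lanes {0, 2, 4, 6} into
/// one cluster and lanes {1, 3, 5, 7} into another.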
struct ClusterInfo {
  unsigned clusterStride;
  unsigned clusterSize;
  unsigned subgroupSize;
};

static FailureOr<ClusterInfo>
getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
  assert(llvm::isPowerOf2_32(subgroupSize));

  std::optional<uint32_t> clusterSize = op.getClusterSize();
  assert(!clusterSize ||
         llvm::isPowerOf2_32(*clusterSize)); // Verifier should've caught this.
  if (clusterSize && *clusterSize > subgroupSize)
    return op.emitOpError()
           << "cluster size " << *clusterSize
           << " is greater than subgroup size " << subgroupSize;
  unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);

  auto clusterStride = op.getClusterStride();
  assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should've caught this.
  if (clusterStride >= subgroupSize)
    return op.emitOpError()
           << "cluster stride " << clusterStride
           << " is not less than subgroup size " << subgroupSize;

  return ClusterInfo{clusterStride, effectiveClusterSize, subgroupSize};
}

/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
/// and `unpackFn` to convert to the native shuffle type and to the reduction
/// type, respectively. For example, with `input` of type `f16`, `packFn` could
/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
/// the subgroup is `subgroupSize` lanes wide and divides it into clusters of
/// `clusterSize` lanes starting at lane 0 with a stride of `clusterStride` for
/// lanes within a cluster, reducing all lanes in each cluster in parallel.
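///
/// An illustrative sketch (not emitted verbatim): for a scalar `f32` `add`
/// reduction with `clusterStride = 1`, `clusterSize = 4`, and
/// `subgroupSize = 32`, the butterfly sequence built here is roughly the
/// following, where `%c1`, `%c2`, and `%c32` are `i32` constants:
/// ```
/// %s0, %valid0 = gpu.shuffle xor %x, %c1, %c32 : f32
/// %r0 = arith.addf %x, %s0 : f32
/// %s1, %valid1 = gpu.shuffle xor %r0, %c2, %c32 : f32
/// %r1 = arith.addf %r0, %s1 : f32
/// ```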
Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
                                     Value input, gpu::AllReduceOperation mode,
                                     const ClusterInfo &ci,
                                     function_ref<Value(Value)> packFn,
                                     function_ref<Value(Value)> unpackFn) {
  // Lane value always stays in the original type. We use it to perform arith
  // reductions.
  Value laneVal = input;
  // Parallel reduction using butterfly shuffles.
  for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize;
       i <<= 1) {
    Value shuffled = gpu::ShuffleOp::create(builder, loc, packFn(laneVal), i,
                                            /*width=*/ci.subgroupSize,
                                            /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = vector::makeArithReduction(builder, loc,
                                         gpu::convertReductionKind(mode),
                                         laneVal, unpackFn(shuffled));
    assert(laneVal.getType() == input.getType());
  }

  return laneVal;
}

/// Lowers scalar gpu subgroup reductions to a series of shuffles.
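///
/// For scalars narrower than the shuffle bitwidth, the value is packed into
/// the shuffle integer type before each shuffle and unpacked afterwards. An
/// illustrative sketch (not emitted verbatim), assuming an `f16` input, a
/// 32-bit shuffle width, and `i32` shuffle operands `%offset` and `%width`:
/// ```
/// %i16 = arith.bitcast %x : f16 to i16
/// %i32 = arith.extui %i16 : i16 to i32
/// %shfl, %valid = gpu.shuffle xor %i32, %offset, %width : i32
/// %t16 = arith.trunci %shfl : i32 to i16
/// %y = arith.bitcast %t16 : i16 to f16
/// ```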
struct ScalarSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    Type valueTy = op.getType();
    unsigned elemBitwidth =
        getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
    if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "value type is not a compatible scalar");

    Location loc = op.getLoc();
    // Since this is already a native shuffle scalar, no packing is necessary.
    if (elemBitwidth == shuffleBitwidth) {
      auto identityFn = [](Value v) { return v; };
      rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                 rewriter, loc, op.getValue(), op.getOp(), *ci,
                                 identityFn, identityFn));
      return success();
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto equivIntType = rewriter.getIntegerType(elemBitwidth);
    auto packFn = [loc, &rewriter, equivIntType,
                   shuffleIntType](Value unpackedVal) -> Value {
      auto asInt =
          arith::BitcastOp::create(rewriter, loc, equivIntType, unpackedVal);
      return arith::ExtUIOp::create(rewriter, loc, shuffleIntType, asInt);
    };
    auto unpackFn = [loc, &rewriter, equivIntType,
                     valueTy](Value packedVal) -> Value {
      auto asInt =
          arith::TruncIOp::create(rewriter, loc, equivIntType, packedVal);
      return arith::BitcastOp::create(rewriter, loc, valueTy, asInt);
    };

    rewriter.replaceOp(
        op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
                                           op.getOp(), *ci, packFn, unpackFn));
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};

/// Lowers vector gpu subgroup reductions to a series of shuffles.
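///
/// Small vectors are bitcast to a single shuffle-sized integer around each
/// shuffle and bitcast back for the arithmetic reduction. An illustrative
/// sketch (not emitted verbatim), assuming a `vector<2xf16>` input, a 32-bit
/// shuffle width, and `i32` shuffle operands `%offset` and `%width`:
/// ```
/// %iv = vector.bitcast %x : vector<2xf16> to vector<1xi32>
/// %i = vector.extract %iv[0] : i32 from vector<1xi32>
/// %shfl, %valid = gpu.shuffle xor %i, %offset, %width : i32
/// %bv = vector.broadcast %shfl : i32 to vector<1xi32>
/// %y = vector.bitcast %bv : vector<1xi32> to vector<2xf16>
/// ```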
struct VectorSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy)
      return rewriter.notifyMatchFailure(op, "value type is not a vector");

    unsigned vecBitwidth =
        vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
    if (vecBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op,
          llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
                        "to shuffles of size {1}",
                        vecBitwidth, shuffleBitwidth));

    unsigned elementsPerShuffle =
        shuffleBitwidth / vecTy.getElementTypeBitWidth();
    if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "shuffle bitwidth is not a multiple of the element bitwidth");

    Location loc = op.getLoc();

    // If the reduced type is smaller than the native shuffle size, extend it,
    // perform the shuffles, and extract at the end.
    auto extendedVecTy = VectorType::get(
        static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
    Value extendedInput = op.getValue();
    if (vecBitwidth < shuffleBitwidth) {
      auto zero = arith::ConstantOp::create(
          rewriter, loc, rewriter.getZeroAttr(extendedVecTy));
      extendedInput = vector::InsertStridedSliceOp::create(
          rewriter, loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto shuffleVecType = VectorType::get(1, shuffleIntType);

    auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
      auto asIntVec =
          vector::BitCastOp::create(rewriter, loc, shuffleVecType, unpackedVal);
      return vector::ExtractOp::create(rewriter, loc, asIntVec, 0);
    };
    auto unpackFn = [loc, &rewriter, shuffleVecType,
                     extendedVecTy](Value packedVal) -> Value {
      auto asIntVec =
          vector::BroadcastOp::create(rewriter, loc, shuffleVecType, packedVal);
      return vector::BitCastOp::create(rewriter, loc, extendedVecTy, asIntVec);
    };

    Value res = createSubgroupShuffleReduction(
        rewriter, loc, extendedInput, op.getOp(), *ci, packFn, unpackFn);

    if (vecBitwidth < shuffleBitwidth) {
      res = vector::ExtractStridedSliceOp::create(
          rewriter, loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
          /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};

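/// Emits a subgroup reduction over a cluster of contiguous lanes using AMD DPP
/// (Data Parallel Primitives) ops, doubling the reduced width at each step:
/// quad_perm swaps handle cluster sizes 2 and 4, row_half_mirror and
/// row_mirror handle 8 and 16, and chipset-dependent cross-row operations
/// (row_bcast plus readlane on gfx9 and earlier, permlanex16 plus readlane on
/// gfx10-12) handle 32 and 64. Returns failure on chipsets where the cross-row
/// steps are not supported.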
static FailureOr<Value>
createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
                           Value input, gpu::AllReduceOperation mode,
                           const ClusterInfo &ci, amdgpu::Chipset chipset) {
  Location loc = op.getLoc();
  Value dpp;
  Value res = input;
  constexpr int allRows = 0xf;
  constexpr int allBanks = 0xf;
  const bool boundCtrl = true;
  if (ci.clusterSize >= 2) {
    // Perform reduction between all lanes N <-> N+1.
    dpp = amdgpu::DPPOp::create(
        rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
        rewriter.getI32ArrayAttr({1, 0, 3, 2}), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }

  if (ci.clusterSize >= 4) {
    // Perform reduction between all lanes N <-> N+2.
    dpp = amdgpu::DPPOp::create(
        rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm,
        rewriter.getI32ArrayAttr({2, 3, 0, 1}), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 8) {
    // Perform reduction between all lanes N <-> 7-N,
    // e.g. lane[0] <-> lane[7], lane[1] <-> lane[6]..., lane[3] <-> lane[4].
    dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res,
                                amdgpu::DPPPerm::row_half_mirror,
                                rewriter.getUnitAttr(), allRows, allBanks,
                                boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 16) {
    // Perform reduction between all lanes N <-> 15-N,
    // e.g. lane[0] <-> lane[15], lane[1] <-> lane[14]..., lane[7] <-> lane[8].
    dpp = amdgpu::DPPOp::create(
        rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::row_mirror,
        rewriter.getUnitAttr(), allRows, allBanks, boundCtrl);
    res = vector::makeArithReduction(rewriter, loc,
                                     gpu::convertReductionKind(mode), res, dpp);
  }
  if (ci.clusterSize >= 32) {
    if (chipset.majorVersion <= 9) {
      // Broadcast last value from each row to next row.
      // Use row mask to avoid polluting row 0 (and row 2 if wave-64).
      dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res,
                                  amdgpu::DPPPerm::row_bcast_15,
                                  rewriter.getUnitAttr(), 0xa, allBanks,
                                  /*bound_ctrl*/ false);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);

      // For subgroupSize = 64, at this point lanes [16, 32) contain the full
      // reduction over lanes [0, 32), but lanes [0, 16) do not. Similarly,
      // lanes [48, 64) contain the full reduction over lanes [32, 64), but
      // lanes [32, 48) do not.
      //
      // If subgroup size is 64 and cluster size is 64, we don't need lanes [0,
      // 16) and [32, 48) to have the correct cluster-32 reduction values at
      // this point, because only lane 63's value will ultimately be read in
      // this full-cluster case.
      //
      // If subgroup size is 64 and cluster size is 32, we need to ensure that
      // lanes [0, 16) and [32, 48) have the correct final cluster-32 reduction
      // values (subgroup_reduce guarantees that all lanes within each cluster
      // contain the final reduction value). We do this by broadcasting lane
      // 31's value to lanes [0, 16) and lane 63's value to lanes [32, 48).
      //
      // See https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations
      // for an illustration of how this within-cluster broadcast works with a
      // swizzle.
      if (ci.subgroupSize == 64 && ci.clusterSize == 32) {
        res =
            amdgpu::SwizzleBitModeOp::create(rewriter, loc, res, /*and_mask=*/0,
                                             /*or_mask=*/31,
                                             /*xor_mask=*/0);
      }
    } else if (chipset.majorVersion <= 12) {
      // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2).
      Value uint32Max = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(-1));
      dpp = ROCDL::PermlaneX16Op::create(rewriter, loc, res.getType(), res, res,
                                         uint32Max, uint32Max,
                                         /*fi=*/true,
                                         /*boundControl=*/false);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
    } else {
      return rewriter.notifyMatchFailure(
          op, "Subgroup reduce lowering to DPP not currently supported for "
              "this device.");
    }
    if (ci.subgroupSize == 32) {
      Value lane31 = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31));
      res =
          ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane31);
    }
  }
  if (ci.clusterSize >= 64) {
    if (chipset.majorVersion <= 9) {
      // Broadcast lane 31's value to rows 2 and 3.
      dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res,
                                  amdgpu::DPPPerm::row_bcast_31,
                                  rewriter.getUnitAttr(), 0xf, allBanks,
                                  /*bound_ctrl*/ true);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), dpp, res);
      // Obtain the reduction from the last rows; the previous rows are
      // polluted.
      Value lane63 = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
      res =
          ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane63);

    } else if (chipset.majorVersion <= 12) {
      // Assume the reduction across each group of 32 lanes has been done.
      // Perform the final reduction manually by combining the values read from
      // lane 31 and lane 63.
      Value lane31 = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31));
      Value lane63 = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63));
      lane31 =
          ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane31);
      lane63 =
          ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane63);
      res = vector::makeArithReduction(
          rewriter, loc, gpu::convertReductionKind(mode), lane31, lane63);
    } else {
      return rewriter.notifyMatchFailure(
          op, "Subgroup reduce lowering to DPP not currently supported for "
              "this device.");
    }
  }
  assert(res.getType() == input.getType());
  return res;
}

/// Lowers scalar `gpu.subgroup_reduce` ops into a series of `amdgpu.dpp` ops.
/// Assumes that the subgroup has `subgroupSize` lanes. Applicable only to AMD
/// GPUs.
struct ScalarSubgroupReduceToDPP final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToDPP(MLIRContext *ctx, unsigned subgroupSize,
                            bool matchClustered, amdgpu::Chipset chipset,
                            PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        matchClustered(matchClustered), chipset(chipset) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }
    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    if (ci->clusterStride != 1)
      return rewriter.notifyMatchFailure(
          op, "Subgroup reductions using DPP are currently only available for "
              "clusters of contiguous lanes.");

    Type valueTy = op.getType();
    if (!valueTy.isIntOrFloat())
      return rewriter.notifyMatchFailure(
          op, "Value type is not a compatible scalar.");

    FailureOr<Value> dpp = createSubgroupDPPReduction(
        rewriter, op, op.getValue(), op.getOp(), *ci, chipset);
    if (failed(dpp))
      return failure();

    rewriter.replaceOp(op, dpp.value());
    return success();
  }

private:
  unsigned subgroupSize = 0;
  bool matchClustered = false;
  amdgpu::Chipset chipset;
};
} // namespace

void mlir::populateGpuBreakDownSubgroupReducePatterns(
    RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
    PatternBenefit benefit) {
  patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
                                        maxShuffleBitwidth, benefit);
  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}

void mlir::populateGpuLowerSubgroupReduceToDPPPatterns(
    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
    PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
                                          /*matchClustered=*/false, chipset,
                                          benefit);
}

void mlir::populateGpuLowerClusteredSubgroupReduceToDPPPatterns(
    RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
    PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize,
                                          /*matchClustered=*/true, chipset,
                                          benefit);
}

void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/false, benefit);
}

void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/true, benefit);
}
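
// Example usage (an illustrative sketch, not part of the upstream file): a
// pass would typically combine the break-down and lowering patterns and apply
// them greedily. `funcOp`, `ctx`, and the chosen widths below are assumptions
// made for illustration only.
//
//   RewritePatternSet patterns(ctx);
//   populateGpuBreakDownSubgroupReducePatterns(patterns,
//                                              /*maxShuffleBitwidth=*/32);
//   populateGpuLowerSubgroupReduceToShufflePatterns(patterns,
//                                                   /*subgroupSize=*/32,
//                                                   /*shuffleBitwidth=*/32);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();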