MLIR 23.0.0git
SubgroupIdRewriter.cpp
Go to the documentation of this file.
1//===- SubgroupIdRewriter.cpp - Implementation of SubgroupId rewriting ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements in-dialect rewriting of the gpu.subgroup_id op for archs
10// where:
11// subgroup_id = (tid.x + dim.x * (tid.y + dim.y * tid.z)) / subgroup_size
12//
13//===----------------------------------------------------------------------===//
14
19#include "mlir/IR/Builders.h"
21
22using namespace mlir;
23
24namespace {
25struct GpuSubgroupIdRewriter final : OpRewritePattern<gpu::SubgroupIdOp> {
26 using OpRewritePattern<gpu::SubgroupIdOp>::OpRewritePattern;
27
28 LogicalResult matchAndRewrite(gpu::SubgroupIdOp op,
29 PatternRewriter &rewriter) const override {
30 // Calculation of the thread's subgroup identifier.
31 //
32 // The process involves mapping the thread's 3D identifier within its
33 // block (b_id.x, b_id.y, b_id.z) to a 1D linear index.
34 // This linearization assumes a layout where the x-dimension (w_dim.x)
35 // varies most rapidly (i.e., it is the innermost dimension).
36 //
37 // The formula for the linearized thread index is:
38 // L = tid.x + dim.x * (tid.y + (dim.y * tid.z))
39 //
40 // Subsequently, the range of linearized indices [0, N_threads-1] is
41 // divided into consecutive, non-overlapping segments, each representing
42 // a subgroup of size 'subgroup_size'.
43 //
44 // Example Partitioning (N = subgroup_size):
45 // | Subgroup 0 | Subgroup 1 | Subgroup 2 | ... |
46 // | Indices 0..N-1 | Indices N..2N-1 | Indices 2N..3N-1| ... |
47 //
48 // The subgroup identifier is obtained via integer division of the
49 // linearized thread index by the predefined 'subgroup_size'.
50 //
51 // subgroup_id = floor( L / subgroup_size )
52 // = (tid.x + dim.x * (tid.y + dim.y * tid.z)) /
53 // subgroup_size
54
55 Location loc = op->getLoc();
56 Type indexType = rewriter.getIndexType();
57
58 auto asMaybeIndexAttr = [&](std::optional<uint32_t> bound) -> IntegerAttr {
59 if (!bound)
60 return IntegerAttr();
61 return IntegerAttr::get(
62 indexType, static_cast<int64_t>(static_cast<uint64_t>(*bound)));
63 };
64
65 IntegerAttr maybeKnownDimX =
66 asMaybeIndexAttr(gpu::getKnownDimensionSizeAround(
67 op, gpu::DimensionKind::Block, gpu::Dimension::x));
68 IntegerAttr maybeKnownDimY =
69 asMaybeIndexAttr(gpu::getKnownDimensionSizeAround(
70 op, gpu::DimensionKind::Block, gpu::Dimension::y));
71 IntegerAttr maybeKnownDimZ =
72 asMaybeIndexAttr(gpu::getKnownDimensionSizeAround(
73 op, gpu::DimensionKind::Block, gpu::Dimension::z));
74
75 Value dimX, dimY;
76 if (maybeKnownDimX)
77 dimX = arith::ConstantOp::create(rewriter, loc, maybeKnownDimX);
78 else
79 dimX = gpu::BlockDimOp::create(rewriter, loc, gpu::Dimension::x);
80 if (maybeKnownDimY)
81 dimY = arith::ConstantOp::create(rewriter, loc, maybeKnownDimY);
82 else
83 dimY = gpu::BlockDimOp::create(rewriter, loc, gpu::Dimension::y);
84
85 Value tidX = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x,
86 maybeKnownDimX);
87 Value tidY = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::y,
88 maybeKnownDimY);
89 Value tidZ = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::z,
90 maybeKnownDimZ);
91
92 // Block dimensions don't exceed a signed int32_t maximum, and neither does
93 // their product, on any realistic hardware, nor would any targets compile
94 // with index < 32 bits, so we can assert no overflow.
95 auto flags =
96 arith::IntegerOverflowFlags::nsw | arith::IntegerOverflowFlags::nuw;
97 Value dimYxIdZ =
98 arith::MulIOp::create(rewriter, loc, indexType, dimY, tidZ, flags);
99 Value dimYxIdZPlusIdY =
100 arith::AddIOp::create(rewriter, loc, indexType, dimYxIdZ, tidY, flags);
101 Value dimYxIdZPlusIdYTimesDimX = arith::MulIOp::create(
102 rewriter, loc, indexType, dimX, dimYxIdZPlusIdY, flags);
103 Value idXPlusDimYxIdZPlusIdYTimesDimX = arith::AddIOp::create(
104 rewriter, loc, indexType, tidX, dimYxIdZPlusIdYTimesDimX, flags);
105 Value subgroupSize = gpu::SubgroupSizeOp::create(
106 rewriter, loc, rewriter.getIndexType(), /*upper_bound = */ nullptr);
107 Value subgroupIdOp =
108 arith::DivUIOp::create(rewriter, loc, indexType,
109 idXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
110 rewriter.replaceOp(op, {subgroupIdOp});
111 return success();
112 }
113};
114
115} // namespace
116
118 patterns.add<GpuSubgroupIdRewriter>(patterns.getContext());
119}
return success()
IndexType getIndexType()
Definition Builders.cpp:55
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::optional< uint32_t > getKnownDimensionSizeAround(Operation *op, DimensionKind kind, Dimension dim)
Retrieve the constant bounds for a given dimension and dimension kind from the context surrounding op...
Include the generated interface declarations.
void populateGpuSubgroupIdPatterns(RewritePatternSet &patterns)
Collect a set of patterns to rewrite SubgroupIdOp op within the GPU dialect.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...