MLIR 22.0.0git
InferIntRangeInterfaceImpls.cpp
Go to the documentation of this file.
1//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for gpu -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/IR/Matchers.h"
13#include "llvm/Support/ErrorHandling.h"
14#include <optional>
15
16using namespace mlir;
17using namespace mlir::gpu;
18
// No known GPU has a grid or block dimension of 2^32 or more, so the largest
// 32-bit unsigned value is used as the fallback bound for such dimensions.
static constexpr uint64_t kMaxDim =
    static_cast<uint64_t>(std::numeric_limits<uint32_t>::max());
// Fallback bound on the cluster size (blocks per cluster along a dimension).
static constexpr uint64_t kMaxClusterDim = 8;
// Subgroup sizes are no larger than 128.
static constexpr uint64_t kMaxSubgroupSize = 128;
25
26static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) {
27 unsigned width = IndexType::kInternalStorageBitWidth;
28 return ConstantIntRanges::fromUnsigned(APInt(width, umin),
29 APInt(width, umax));
30}
31
namespace {
/// Selects which family of launch bounds a query refers to.
enum class LaunchDims : uint32_t {
  Block = 0,   // Block size.
  Grid = 1,    // Grid size.
  Cluster = 2, // Cluster size.
};
} // namespace
35
// Helpers for `getKnownLaunchDim` (defined below). If the operation `op` is in
// a context that is annotated with maximum launch dimensions (an enclosing
// gpu.launch with constant block, grid, or cluster sizes, or an enclosing
// function carrying known block/grid/cluster size attributes), that function
// returns the bound on the size of the dimension the op is querying.
// IDs will be at most one less than this bound.
41
42static Value valueByDim(KernelDim3 dims, Dimension dim) {
43 switch (dim) {
44 case Dimension::x:
45 return dims.x;
46 case Dimension::y:
47 return dims.y;
48 case Dimension::z:
49 return dims.z;
50 }
51 llvm_unreachable("All dimension enum cases handled above");
52}
53
/// Zero-extends a 32-bit unsigned value to 64 bits.
static uint64_t zext(uint32_t arg) {
  uint64_t widened = arg;
  return widened;
}
55
56static std::optional<uint64_t>
57getKnownLaunchAttr(GPUFuncOp func, LaunchDims dims, Dimension dim) {
58 DenseI32ArrayAttr bounds;
59 switch (dims) {
60 case LaunchDims::Block:
61 bounds = func.getKnownBlockSizeAttr();
62 break;
63 case LaunchDims::Grid:
64 bounds = func.getKnownGridSizeAttr();
65 break;
66 case LaunchDims::Cluster:
67 bounds = func.getKnownClusterSizeAttr();
68 break;
69 }
70 if (!bounds)
71 return std::nullopt;
72 if (bounds.size() < static_cast<uint32_t>(dim))
73 return std::nullopt;
74 return zext(bounds[static_cast<uint32_t>(dim)]);
75}
76
77static std::optional<uint64_t> getKnownLaunchAttr(FunctionOpInterface func,
78 StringRef attrName,
79 Dimension dim) {
80 auto bounds = func.getOperation()->getAttrOfType<DenseI32ArrayAttr>(attrName);
81 if (!bounds)
82 return std::nullopt;
83 if (bounds.size() < static_cast<uint32_t>(dim))
84 return std::nullopt;
85 return zext(bounds[static_cast<uint32_t>(dim)]);
86}
87
88template <typename Op>
89static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
90 Dimension dim = op.getDimension();
91 if (auto launch = op->template getParentOfType<LaunchOp>()) {
92 KernelDim3 bounds;
93 switch (type) {
94 case LaunchDims::Block:
95 bounds = launch.getBlockSizeOperandValues();
96 break;
97 case LaunchDims::Grid:
98 bounds = launch.getGridSizeOperandValues();
99 break;
100 case LaunchDims::Cluster:
101 if (launch.hasClusterSize()) {
102 auto clusterBounds = launch.getClusterSizeOperandValues();
103 if (clusterBounds)
104 bounds = *clusterBounds;
105 }
106 break;
107 }
108 Value maybeBound = valueByDim(bounds, dim);
109 APInt value;
110 if (matchPattern(maybeBound, m_ConstantInt(&value)))
111 return value.getZExtValue();
112 }
113
114 if (auto gpuFunc = op->template getParentOfType<GPUFuncOp>()) {
115 auto inherentAttr = getKnownLaunchAttr(gpuFunc, type, dim);
116 if (inherentAttr)
117 return inherentAttr;
118 }
119 if (auto func = op->template getParentOfType<FunctionOpInterface>()) {
120 StringRef attrName;
121 switch (type) {
122 case LaunchDims::Block:
123 attrName = GPUDialect::KnownBlockSizeAttrHelper::getNameStr();
124 break;
125 case LaunchDims::Grid:
126 attrName = GPUDialect::KnownGridSizeAttrHelper::getNameStr();
127 break;
128 case LaunchDims::Cluster:
129 attrName = GPUDialect::KnownClusterSizeAttrHelper::getNameStr();
130 break;
131 }
132 auto discardableAttr = getKnownLaunchAttr(func, attrName, dim);
133 if (discardableAttr)
134 return discardableAttr;
135 }
136 return std::nullopt;
137}
138
139void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
140 SetIntRangeFn setResultRange) {
141 uint64_t max = kMaxDim;
142 if (auto specified = getUpperBound())
143 max = specified->getZExtValue();
144 setResultRange(getResult(), getIndexRange(1, max));
145}
146
147void ClusterDimBlocksOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
148 SetIntRangeFn setResultRange) {
149 if (auto known = getKnownLaunchDim(*this, LaunchDims::Cluster))
150 return setResultRange(getResult(), getIndexRange(*known, *known));
151
152 uint64_t max = kMaxClusterDim;
153 if (auto specified = getUpperBound())
154 max = specified->getZExtValue();
155 setResultRange(getResult(), getIndexRange(1, max));
156}
157
158void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
159 SetIntRangeFn setResultRange) {
160 uint64_t max = kMaxDim;
161 if (auto specified = getUpperBound())
162 max = specified->getZExtValue();
163 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
164}
165
166void ClusterBlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
167 SetIntRangeFn setResultRange) {
168 uint64_t max = kMaxClusterDim;
169 if (auto known = getKnownLaunchDim(*this, LaunchDims::Cluster))
170 max = *known;
171 if (auto specified = getUpperBound())
172 max = specified->getZExtValue();
173 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
174}
175
176void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
177 SetIntRangeFn setResultRange) {
178 std::optional<uint64_t> knownVal =
179 getKnownLaunchDim(*this, LaunchDims::Block);
180 if (knownVal)
181 return setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
182 ;
183 uint64_t max = kMaxDim;
184 if (auto specified = getUpperBound())
185 max = specified->getZExtValue();
186 setResultRange(getResult(), getIndexRange(1, max));
187}
188
189void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
190 SetIntRangeFn setResultRange) {
191 uint64_t max = kMaxDim;
192 if (auto fromContext = getKnownLaunchDim(*this, LaunchDims::Grid))
193 max = fromContext.value();
194 if (auto specified = getUpperBound())
195 max = specified->getZExtValue();
196 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
197}
198
199void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
200 SetIntRangeFn setResultRange) {
201 std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid);
202 if (knownVal)
203 return setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
204 uint64_t max = kMaxDim;
205 if (auto specified = getUpperBound())
206 max = specified->getZExtValue();
207 setResultRange(getResult(), getIndexRange(1, max));
208}
209
210void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
211 SetIntRangeFn setResultRange) {
212 uint64_t max = kMaxDim;
213 if (auto fromContext = getKnownLaunchDim(*this, LaunchDims::Block))
214 max = fromContext.value();
215 if (auto specified = getUpperBound())
216 max = specified->getZExtValue();
217 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
218}
219
220void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
221 SetIntRangeFn setResultRange) {
222 uint64_t max = kMaxSubgroupSize;
223 if (auto specified = getUpperBound())
224 max = specified->getZExtValue();
225 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
226}
227
228void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
229 SetIntRangeFn setResultRange) {
230 uint64_t max = kMaxDim;
231 if (auto specified = getUpperBound())
232 max = specified->getZExtValue();
233 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
234}
235
236void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
237 SetIntRangeFn setResultRange) {
238 if (auto specified = getUpperBound())
239 return setResultRange(getResult(),
240 getIndexRange(0, specified->getZExtValue() - 1ULL));
241
242 uint64_t blockDimMax =
243 getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
244 uint64_t gridDimMax =
245 getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
246 setResultRange(getResult(),
247 getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL));
248}
249
250void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
251 SetIntRangeFn setResultRange) {
252 uint64_t max = kMaxDim;
253 if (auto specified = getUpperBound())
254 max = specified->getZExtValue();
255 setResultRange(getResult(), getIndexRange(1, max));
256}
257
258void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
259 SetIntRangeFn setResultRange) {
260 uint64_t max = kMaxSubgroupSize;
261 if (auto specified = getUpperBound())
262 max = specified->getZExtValue();
263 setResultRange(getResult(), getIndexRange(1, max));
264}
265
266void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
267 SetIntRangeFn setResultRange) {
268 auto setRange = [&](const ConstantIntRanges &argRange, Value dimResult,
269 Value idxResult) {
270 if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
271 return;
272 ConstantIntRanges dimRange =
273 argRange.intersection(getIndexRange(1, kMaxDim));
274 setResultRange(dimResult, dimRange);
275 ConstantIntRanges idxRange =
276 getIndexRange(0, dimRange.umax().getZExtValue() - 1);
277 setResultRange(idxResult, idxRange);
278 };
279
280 argRanges = argRanges.drop_front(getAsyncDependencies().size());
281 KernelDim3 gridDims = getGridSize();
282 KernelDim3 blockIds = getBlockIds();
283 setRange(argRanges[0], gridDims.x, blockIds.x);
284 setRange(argRanges[1], gridDims.y, blockIds.y);
285 setRange(argRanges[2], gridDims.z, blockIds.z);
286 KernelDim3 blockDims = getBlockSize();
287 KernelDim3 threadIds = getThreadIds();
288 setRange(argRanges[3], blockDims.x, threadIds.x);
289 setRange(argRanges[4], blockDims.y, threadIds.y);
290 setRange(argRanges[5], blockDims.z, threadIds.z);
291}
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static Value valueByDim(KernelDim3 dims, Dimension dim)
If the operation op is in a context that is annotated with maximum launch dimensions (a launch op wit...
static constexpr uint64_t kMaxClusterDim
static constexpr uint64_t kMaxDim
static uint64_t zext(uint32_t arg)
static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax)
static std::optional< uint64_t > getKnownLaunchAttr(GPUFuncOp func, LaunchDims dims, Dimension dim)
static constexpr uint64_t kMaxSubgroupSize
static std::optional< uint64_t > getKnownLaunchDim(Op op, LaunchDims type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Block represents an ordered list of Operations.
Definition Block.h:33
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create an ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
ConstantIntRanges intersection(const ConstantIntRanges &other) const
Returns the intersection (computed separately for signed and unsigned bounds) of this range and other...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
This provides public APIs that all operations should have.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
Utility class for the GPU dialect to represent triples of Values accessible through ....
Definition GPUDialect.h:39