MLIR 23.0.0git
InferIntRangeInterfaceImpls.cpp
Go to the documentation of this file.
1//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for gpu -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/IR/Matchers.h"
13#include <optional>
14
15using namespace mlir;
16using namespace mlir::gpu;
17
// Maximum grid and block dimensions of all known GPUs are less than 2^32.
static constexpr uint64_t kMaxDim = std::numeric_limits<uint32_t>::max();
// Maximum cluster size (number of blocks per cluster dimension).
static constexpr uint64_t kMaxClusterDim = 16;
// Maximum subgroups are no larger than 128.
static constexpr uint64_t kMaxSubgroupSize = 128;
24
25static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) {
26 unsigned width = IndexType::kInternalStorageBitWidth;
27 return ConstantIntRanges::fromUnsigned(APInt(width, umin),
28 APInt(width, umax));
29}
30
// Widen a 32-bit launch bound to the 64-bit arithmetic used in this file.
static uint64_t zext(uint32_t arg) { return uint64_t{arg}; }
32
33static Value valueByDim(KernelDim3 dims, Dimension dim) {
34 switch (dim) {
35 case Dimension::x:
36 return dims.x;
37 case Dimension::y:
38 return dims.y;
39 case Dimension::z:
40 return dims.z;
41 }
42 llvm_unreachable("All dimension enum cases handled above");
43}
44
45static std::optional<uint32_t>
46getKnownLaunchAttr(GPUFuncOp func, DimensionKind dims, Dimension dim) {
47 DenseI32ArrayAttr bounds;
48 switch (dims) {
49 case DimensionKind::Other:
50 return std::nullopt;
51 case DimensionKind::Block:
52 bounds = func.getKnownBlockSizeAttr();
53 break;
54 case DimensionKind::Grid:
55 bounds = func.getKnownGridSizeAttr();
56 break;
57 case DimensionKind::Cluster:
58 bounds = func.getKnownClusterSizeAttr();
59 break;
60 }
61 if (!bounds)
62 return std::nullopt;
63 if (bounds.size() <= static_cast<uint32_t>(dim))
64 return std::nullopt;
65 return bounds[static_cast<uint32_t>(dim)];
66}
67
68static std::optional<uint32_t> getKnownLaunchAttr(FunctionOpInterface func,
69 StringRef attrName,
70 Dimension dim) {
71 auto bounds = func.getOperation()->getAttrOfType<DenseI32ArrayAttr>(attrName);
72 if (!bounds)
73 return std::nullopt;
74 if (bounds.size() <= static_cast<uint32_t>(dim))
75 return std::nullopt;
76 return bounds[static_cast<uint32_t>(dim)];
77}
78
// Retrieve the constant launch bound of kind `kind` along dimension `dim`
// from the context surrounding `op`, consulting three sources in priority
// order: (1) an enclosing gpu.launch whose size operands are constants,
// (2) the inherent known-size attributes of an enclosing gpu.func, and
// (3) the discardable gpu.known_*_size attributes on any function-like
// ancestor. Returns std::nullopt if no source provides a constant bound.
// NOTE(review): the second signature line below was reconstructed from the
// function's declaration; confirm against the header.
std::optional<uint32_t>
mlir::gpu::getKnownDimensionSizeAround(Operation *op, DimensionKind kind,
                                       Dimension dim) {
  // Highest priority: a surrounding gpu.launch with constant size operands.
  if (auto launch = op->getParentOfType<LaunchOp>()) {
    KernelDim3 bounds;
    switch (kind) {
    case DimensionKind::Other:
      return std::nullopt;
    case DimensionKind::Block:
      bounds = launch.getBlockSizeOperandValues();
      break;
    case DimensionKind::Grid:
      bounds = launch.getGridSizeOperandValues();
      break;
    case DimensionKind::Cluster:
      // Cluster sizes are optional on gpu.launch; when absent, `bounds`
      // keeps null values and the constant match below fails.
      if (launch.hasClusterSize()) {
        auto clusterBounds = launch.getClusterSizeOperandValues();
        if (clusterBounds)
          bounds = *clusterBounds;
      }
      break;
    }
    Value maybeBound = valueByDim(bounds, dim);
    APInt value;
    if (maybeBound && matchPattern(maybeBound, m_ConstantInt(&value)))
      return value.getZExtValue();
  }

  // Next: inherent known-size attributes on an enclosing gpu.func.
  if (auto gpuFunc = op->getParentOfType<GPUFuncOp>()) {
    auto inherentAttr = getKnownLaunchAttr(gpuFunc, kind, dim);
    if (inherentAttr)
      return inherentAttr;
  }
  // Last: discardable gpu.known_*_size attributes on any function-like
  // ancestor (covers non-gpu.func kernels).
  if (auto func = op->getParentOfType<FunctionOpInterface>()) {
    StringRef attrName;
    switch (kind) {
    case DimensionKind::Other:
      return std::nullopt;
    case DimensionKind::Block:
      attrName = GPUDialect::KnownBlockSizeAttrHelper::getNameStr();
      break;
    case DimensionKind::Grid:
      attrName = GPUDialect::KnownGridSizeAttrHelper::getNameStr();
      break;
    case DimensionKind::Cluster:
      attrName = GPUDialect::KnownClusterSizeAttrHelper::getNameStr();
      break;
    }
    auto discardableAttr = getKnownLaunchAttr(func, attrName, dim);
    if (discardableAttr)
      return discardableAttr;
  }
  return std::nullopt;
}
133
134void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
135 SetIntRangeFn setResultRange) {
136 uint64_t max = kMaxDim;
137 if (auto specified = getUpperBound())
138 max = specified->getZExtValue();
139 setResultRange(getResult(), getIndexRange(1, max));
140}
141
142void ClusterDimBlocksOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
143 SetIntRangeFn setResultRange) {
144 if (auto known = getKnownDimensionSizeAround(*this, DimensionKind::Cluster,
145 getDimension()))
146 return setResultRange(getResult(),
147 getIndexRange(zext(*known), zext(*known)));
148
149 uint64_t max = kMaxClusterDim;
150 if (auto specified = getUpperBound())
151 max = specified->getZExtValue();
152 setResultRange(getResult(), getIndexRange(1, max));
153}
154
155void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
156 SetIntRangeFn setResultRange) {
157 uint64_t max = kMaxDim;
158 if (auto specified = getUpperBound())
159 max = specified->getZExtValue();
160 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
161}
162
163void ClusterBlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
164 SetIntRangeFn setResultRange) {
165 uint64_t max = kMaxClusterDim;
166 if (auto known = getKnownDimensionSizeAround(*this, DimensionKind::Cluster,
167 getDimension()))
168 max = zext(*known);
169 if (auto specified = getUpperBound())
170 max = specified->getZExtValue();
171 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
172}
173
174void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
175 SetIntRangeFn setResultRange) {
176 std::optional<uint32_t> knownVal =
177 getKnownDimensionSizeAround(*this, DimensionKind::Block, getDimension());
178 if (knownVal)
179 return setResultRange(getResult(),
180 getIndexRange(zext(*knownVal), zext(*knownVal)));
181
182 uint64_t max = kMaxDim;
183 if (auto specified = getUpperBound())
184 max = specified->getZExtValue();
185 setResultRange(getResult(), getIndexRange(1, max));
186}
187
188void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
189 SetIntRangeFn setResultRange) {
190 uint64_t max = kMaxDim;
191 if (auto fromContext = getKnownDimensionSizeAround(*this, DimensionKind::Grid,
192 getDimension()))
193 max = zext(*fromContext);
194 if (auto specified = getUpperBound())
195 max = specified->getZExtValue();
196 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
197}
198
199void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
200 SetIntRangeFn setResultRange) {
201 std::optional<uint32_t> knownVal =
202 getKnownDimensionSizeAround(*this, DimensionKind::Grid, getDimension());
203 if (knownVal)
204 return setResultRange(getResult(),
205 getIndexRange(zext(*knownVal), zext(*knownVal)));
206 uint64_t max = kMaxDim;
207 if (auto specified = getUpperBound())
208 max = specified->getZExtValue();
209 setResultRange(getResult(), getIndexRange(1, max));
210}
211
212void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
213 SetIntRangeFn setResultRange) {
214 uint64_t max = kMaxDim;
215 if (auto fromContext = getKnownDimensionSizeAround(
216 *this, DimensionKind::Block, getDimension()))
217 max = zext(*fromContext);
218 if (auto specified = getUpperBound())
219 max = specified->getZExtValue();
220 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
221}
222
223void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
224 SetIntRangeFn setResultRange) {
225 uint64_t max = kMaxSubgroupSize;
226 if (auto specified = getUpperBound())
227 max = specified->getZExtValue();
228 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
229}
230
231void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
232 SetIntRangeFn setResultRange) {
233 uint64_t max = kMaxDim;
234 if (auto specified = getUpperBound())
235 max = specified->getZExtValue();
236 setResultRange(getResult(), getIndexRange(0, max - 1ULL));
237}
238
239void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
240 SetIntRangeFn setResultRange) {
241 if (auto specified = getUpperBound())
242 return setResultRange(getResult(),
243 getIndexRange(0, specified->getZExtValue() - 1ULL));
244
245 uint64_t blockDimMax = zext(
246 getKnownDimensionSizeAround(*this, DimensionKind::Block, getDimension())
247 .value_or(kMaxDim));
248 uint64_t gridDimMax = zext(
249 getKnownDimensionSizeAround(*this, DimensionKind::Grid, getDimension())
250 .value_or(kMaxDim));
251 setResultRange(getResult(),
252 getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL));
253}
254
255void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
256 SetIntRangeFn setResultRange) {
257 uint64_t max = kMaxDim;
258 if (auto specified = getUpperBound())
259 max = specified->getZExtValue();
260 setResultRange(getResult(), getIndexRange(1, max));
261}
262
263void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
264 SetIntRangeFn setResultRange) {
265 uint64_t max = kMaxSubgroupSize;
266 if (auto specified = getUpperBound())
267 max = specified->getZExtValue();
268 setResultRange(getResult(), getIndexRange(1, max));
269}
270
271void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
272 SetIntRangeFn setResultRange) {
273 auto setRange = [&](const ConstantIntRanges &argRange, Value dimResult,
274 Value idxResult) {
275 if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
276 return;
277 ConstantIntRanges dimRange =
278 argRange.intersection(getIndexRange(1, kMaxDim));
279 setResultRange(dimResult, dimRange);
280 ConstantIntRanges idxRange =
281 getIndexRange(0, dimRange.umax().getZExtValue() - 1);
282 setResultRange(idxResult, idxRange);
283 };
284
285 argRanges = argRanges.drop_front(getAsyncDependencies().size());
286 KernelDim3 gridDims = getGridSize();
287 KernelDim3 blockIds = getBlockIds();
288 setRange(argRanges[0], gridDims.x, blockIds.x);
289 setRange(argRanges[1], gridDims.y, blockIds.y);
290 setRange(argRanges[2], gridDims.z, blockIds.z);
291 KernelDim3 blockDims = getBlockSize();
292 KernelDim3 threadIds = getThreadIds();
293 setRange(argRanges[3], blockDims.x, threadIds.x);
294 setRange(argRanges[4], blockDims.y, threadIds.y);
295 setRange(argRanges[5], blockDims.z, threadIds.z);
296}
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< uint32_t > getKnownLaunchAttr(GPUFuncOp func, DimensionKind dims, Dimension dim)
static Value valueByDim(KernelDim3 dims, Dimension dim)
static constexpr uint64_t kMaxClusterDim
static constexpr uint64_t kMaxDim
static uint64_t zext(uint32_t arg)
static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax)
static constexpr uint64_t kMaxSubgroupSize
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create an ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
ConstantIntRanges intersection(const ConstantIntRanges &other) const
Returns the intersection (computed separately for signed and unsigned bounds) of this range and other...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
std::optional< uint32_t > getKnownDimensionSizeAround(Operation *op, DimensionKind kind, Dimension dim)
Retrieve the constant bounds for a given dimension and dimension kind from the context surrounding op...
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
Utility class for the GPU dialect to represent triples of Values accessible through ....
Definition GPUDialect.h:39