MLIR  20.0.0git
InferIntRangeInterfaceImpls.cpp
Go to the documentation of this file.
1 //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for gpu -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Matchers.h"
13 #include "llvm/ADT/STLForwardCompat.h"
14 #include "llvm/Support/ErrorHandling.h"
15 #include "llvm/Support/MathExtras.h"
16 #include <optional>
17 
18 using namespace mlir;
19 using namespace mlir::gpu;
20 
// Grid and block dimensions on every known GPU are below 2^32, so the largest
// uint32_t value is used as the generic ceiling for launch dimensions.
static constexpr uint64_t kMaxDim = std::numeric_limits<uint32_t>::max();
// Ceiling used for the size of a cluster along one dimension.
static constexpr uint64_t kMaxClusterDim = 8;
// Subgroup sizes are no larger than 128.
static constexpr uint64_t kMaxSubgroupSize = 128;
27 
28 static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) {
29  unsigned width = IndexType::kInternalStorageBitWidth;
30  return ConstantIntRanges::fromUnsigned(APInt(width, umin),
31  APInt(width, umax));
32 }
33 
namespace {
/// Selects which set of launch dimensions a query refers to.
enum class LaunchDims : uint32_t {
  Block = 0, ///< Block (workgroup) sizes.
  Grid = 1,  ///< Grid sizes.
};
} // namespace
37 
38 /// If the operation `op` is in a context that is annotated with maximum
39 /// launch dimensions (a launch op with constant block or grid
40 /// sizes or a launch_func op with the appropriate dimensions), return
41 /// the bound on the maximum size of the dimension that the op is querying.
42 /// IDs will be one less than this bound.
43 
44 static Value valueByDim(KernelDim3 dims, Dimension dim) {
45  switch (dim) {
46  case Dimension::x:
47  return dims.x;
48  case Dimension::y:
49  return dims.y;
50  case Dimension::z:
51  return dims.z;
52  }
53  llvm_unreachable("All dimension enum cases handled above");
54 }
55 
/// Zero-extends a 32-bit unsigned value to 64 bits.
static uint64_t zext(uint32_t arg) {
  const uint64_t widened = arg;
  return widened;
}
57 
58 static std::optional<uint64_t>
59 getKnownLaunchAttr(GPUFuncOp func, LaunchDims dims, Dimension dim) {
60  DenseI32ArrayAttr bounds;
61  switch (dims) {
62  case LaunchDims::Block:
63  bounds = func.getKnownBlockSizeAttr();
64  break;
65  case LaunchDims::Grid:
66  bounds = func.getKnownGridSizeAttr();
67  break;
68  }
69  if (!bounds)
70  return std::nullopt;
71  if (bounds.size() < static_cast<uint32_t>(dim))
72  return std::nullopt;
73  return zext(bounds[static_cast<uint32_t>(dim)]);
74 }
75 
76 static std::optional<uint64_t> getKnownLaunchAttr(FunctionOpInterface func,
77  StringRef attrName,
78  Dimension dim) {
79  auto bounds = func.getOperation()->getAttrOfType<DenseI32ArrayAttr>(attrName);
80  if (!bounds)
81  return std::nullopt;
82  if (bounds.size() < static_cast<uint32_t>(dim))
83  return std::nullopt;
84  return zext(bounds[static_cast<uint32_t>(dim)]);
85 }
86 
87 template <typename Op>
88 static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
89  Dimension dim = op.getDimension();
90  if (auto launch = op->template getParentOfType<LaunchOp>()) {
91  KernelDim3 bounds;
92  switch (type) {
93  case LaunchDims::Block:
94  bounds = launch.getBlockSizeOperandValues();
95  break;
96  case LaunchDims::Grid:
97  bounds = launch.getGridSizeOperandValues();
98  break;
99  }
100  Value maybeBound = valueByDim(bounds, dim);
101  APInt value;
102  if (matchPattern(maybeBound, m_ConstantInt(&value)))
103  return value.getZExtValue();
104  }
105 
106  if (auto gpuFunc = op->template getParentOfType<GPUFuncOp>()) {
107  auto inherentAttr = getKnownLaunchAttr(gpuFunc, type, dim);
108  if (inherentAttr)
109  return inherentAttr;
110  }
111  if (auto func = op->template getParentOfType<FunctionOpInterface>()) {
112  StringRef attrName;
113  switch (type) {
114  case LaunchDims::Block:
115  attrName = GPUDialect::KnownBlockSizeAttrHelper::getNameStr();
116  break;
117  case LaunchDims::Grid:
118  attrName = GPUDialect::KnownGridSizeAttrHelper::getNameStr();
119  break;
120  }
121  auto discardableAttr = getKnownLaunchAttr(func, attrName, dim);
122  if (discardableAttr)
123  return discardableAttr;
124  }
125  return std::nullopt;
126 }
127 
128 void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
129  SetIntRangeFn setResultRange) {
130  uint64_t max = kMaxDim;
131  if (auto specified = getUpperBound())
132  max = specified->getZExtValue();
133  setResultRange(getResult(), getIndexRange(1, max));
134 }
135 
136 void ClusterDimBlocksOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
137  SetIntRangeFn setResultRange) {
138  uint64_t max = kMaxClusterDim;
139  if (auto specified = getUpperBound())
140  max = specified->getZExtValue();
141  setResultRange(getResult(), getIndexRange(1, max));
142 }
143 
144 void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
145  SetIntRangeFn setResultRange) {
146  uint64_t max = kMaxDim;
147  if (auto specified = getUpperBound())
148  max = specified->getZExtValue();
149  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
150 }
151 
152 void ClusterBlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
153  SetIntRangeFn setResultRange) {
154  uint64_t max = kMaxClusterDim;
155  if (auto specified = getUpperBound())
156  max = specified->getZExtValue();
157  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
158 }
159 
160 void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
161  SetIntRangeFn setResultRange) {
162  std::optional<uint64_t> knownVal =
163  getKnownLaunchDim(*this, LaunchDims::Block);
164  if (knownVal)
165  return setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
166  ;
167  uint64_t max = kMaxDim;
168  if (auto specified = getUpperBound())
169  max = specified->getZExtValue();
170  setResultRange(getResult(), getIndexRange(1, max));
171 }
172 
173 void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
174  SetIntRangeFn setResultRange) {
175  uint64_t max = kMaxDim;
176  if (auto fromContext = getKnownLaunchDim(*this, LaunchDims::Grid))
177  max = fromContext.value();
178  if (auto specified = getUpperBound())
179  max = specified->getZExtValue();
180  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
181 }
182 
183 void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
184  SetIntRangeFn setResultRange) {
185  std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid);
186  if (knownVal)
187  return setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
188  uint64_t max = kMaxDim;
189  if (auto specified = getUpperBound())
190  max = specified->getZExtValue();
191  setResultRange(getResult(), getIndexRange(1, max));
192 }
193 
194 void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
195  SetIntRangeFn setResultRange) {
196  uint64_t max = kMaxDim;
197  if (auto fromContext = getKnownLaunchDim(*this, LaunchDims::Block))
198  max = fromContext.value();
199  if (auto specified = getUpperBound())
200  max = specified->getZExtValue();
201  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
202 }
203 
204 void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
205  SetIntRangeFn setResultRange) {
206  uint64_t max = kMaxSubgroupSize;
207  if (auto specified = getUpperBound())
208  max = specified->getZExtValue();
209  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
210 }
211 
212 void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
213  SetIntRangeFn setResultRange) {
214  uint64_t max = kMaxDim;
215  if (auto specified = getUpperBound())
216  max = specified->getZExtValue();
217  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
218 }
219 
220 void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
221  SetIntRangeFn setResultRange) {
222  if (auto specified = getUpperBound())
223  return setResultRange(getResult(),
224  getIndexRange(0, specified->getZExtValue() - 1ULL));
225 
226  uint64_t blockDimMax =
227  getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
228  uint64_t gridDimMax =
229  getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
230  setResultRange(getResult(),
231  getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL));
232 }
233 
234 void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
235  SetIntRangeFn setResultRange) {
236  uint64_t max = kMaxDim;
237  if (auto specified = getUpperBound())
238  max = specified->getZExtValue();
239  setResultRange(getResult(), getIndexRange(1, max));
240 }
241 
242 void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
243  SetIntRangeFn setResultRange) {
244  uint64_t max = kMaxSubgroupSize;
245  if (auto specified = getUpperBound())
246  max = specified->getZExtValue();
247  setResultRange(getResult(), getIndexRange(1, max));
248 }
249 
250 void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
251  SetIntRangeFn setResultRange) {
252  auto setRange = [&](const ConstantIntRanges &argRange, Value dimResult,
253  Value idxResult) {
254  if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
255  return;
256  ConstantIntRanges dimRange =
257  argRange.intersection(getIndexRange(1, kMaxDim));
258  setResultRange(dimResult, dimRange);
259  ConstantIntRanges idxRange =
260  getIndexRange(0, dimRange.umax().getZExtValue() - 1);
261  setResultRange(idxResult, idxRange);
262  };
263 
264  argRanges = argRanges.drop_front(getAsyncDependencies().size());
265  KernelDim3 gridDims = getGridSize();
266  KernelDim3 blockIds = getBlockIds();
267  setRange(argRanges[0], gridDims.x, blockIds.x);
268  setRange(argRanges[1], gridDims.y, blockIds.y);
269  setRange(argRanges[2], gridDims.z, blockIds.z);
270  KernelDim3 blockDims = getBlockSize();
271  KernelDim3 threadIds = getThreadIds();
272  setRange(argRanges[3], blockDims.x, threadIds.x);
273  setRange(argRanges[4], blockDims.y, threadIds.y);
274  setRange(argRanges[5], blockDims.z, threadIds.z);
275 }
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:720
static Value valueByDim(KernelDim3 dims, Dimension dim)
If the operation op is in a context that is annotated with maximum launch dimensions (a launch op wit...
static constexpr uint64_t kMaxClusterDim
static std::optional< uint64_t > getKnownLaunchDim(Op op, LaunchDims type)
static constexpr uint64_t kMaxDim
static uint64_t zext(uint32_t arg)
static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax)
static constexpr uint64_t kMaxSubgroupSize
static std::optional< uint64_t > getKnownLaunchAttr(GPUFuncOp func, LaunchDims dims, Dimension dim)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Block represents an ordered list of Operations.
Definition: Block.h:31
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create an ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
ConstantIntRanges intersection(const ConstantIntRanges &other) const
Returns the intersection (computed separately for signed and unsigned bounds) of this range and other...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
This provides public APIs that all operations should have.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
Utility class for the GPU dialect to represent triples of Values accessible through ....
Definition: GPUDialect.h:38