MLIR  19.0.0git
InferIntRangeInterfaceImpls.cpp
Go to the documentation of this file.
1 //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for gpu -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Matchers.h"
12 #include "llvm/ADT/STLForwardCompat.h"
13 #include "llvm/Support/ErrorHandling.h"
14 #include "llvm/Support/MathExtras.h"
15 #include <optional>
16 
17 using namespace mlir;
18 using namespace mlir::gpu;
19 
// Maximum grid and block dimensions of all known GPUs are less than 2^32, so
// the largest value a dimension can take is 2^32 - 1.
static constexpr uint64_t kMaxDim = (uint64_t{1} << 32) - 1;
// Upper bound used for cluster dimensions.
static constexpr uint64_t kMaxClusterDim = 8;
// Subgroup sizes are no larger than 128.
static constexpr uint64_t kMaxSubgroupSize = 128;
26 
27 static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) {
28  unsigned width = IndexType::kInternalStorageBitWidth;
29  return ConstantIntRanges::fromUnsigned(APInt(width, umin),
30  APInt(width, umax));
31 }
32 
namespace {
/// Selects which set of launch dimensions a query refers to.
enum class LaunchDims : uint32_t {
  /// Block (workgroup) sizes.
  Block = 0,
  /// Grid sizes.
  Grid = 1,
};
} // end namespace
36 
// NOTE: The description below documents `getKnownLaunchDim()`, which is
// defined further down in this file; it does not describe `valueByDim()`.
//
// If the operation `op` is in a context that is annotated with launch
// dimensions (an enclosing `LaunchOp` whose block or grid size operand is a
// constant, or an enclosing `GPUFuncOp` with known block or grid sizes),
// return the size of the dimension that the op is querying.
// IDs are at most one less than this bound.
42 
43 static Value valueByDim(KernelDim3 dims, Dimension dim) {
44  switch (dim) {
45  case Dimension::x:
46  return dims.x;
47  case Dimension::y:
48  return dims.y;
49  case Dimension::z:
50  return dims.z;
51  }
52  llvm_unreachable("All dimension enum cases handled above");
53 }
54 
/// Widens a 32-bit unsigned value to 64 bits (zero extension).
static uint64_t zext(uint32_t arg) {
  return uint64_t{arg};
}
56 
57 template <typename Op>
58 static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
59  Dimension dim = op.getDimension();
60  if (auto launch = op->template getParentOfType<LaunchOp>()) {
61  KernelDim3 bounds;
62  switch (type) {
63  case LaunchDims::Block:
64  bounds = launch.getBlockSizeOperandValues();
65  break;
66  case LaunchDims::Grid:
67  bounds = launch.getGridSizeOperandValues();
68  break;
69  }
70  Value maybeBound = valueByDim(bounds, dim);
71  APInt value;
72  if (matchPattern(maybeBound, m_ConstantInt(&value)))
73  return value.getZExtValue();
74  }
75 
76  if (auto func = op->template getParentOfType<GPUFuncOp>()) {
77  switch (type) {
78  case LaunchDims::Block:
79  return llvm::transformOptional(func.getKnownBlockSize(dim), zext);
80  case LaunchDims::Grid:
81  return llvm::transformOptional(func.getKnownGridSize(dim), zext);
82  }
83  }
84  return std::nullopt;
85 }
86 
87 void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
88  SetIntRangeFn setResultRange) {
89  setResultRange(getResult(), getIndexRange(1, kMaxClusterDim));
90 }
91 
92 void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
93  SetIntRangeFn setResultRange) {
94  uint64_t max = kMaxClusterDim;
95  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
96 }
97 
98 void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
99  SetIntRangeFn setResultRange) {
100  std::optional<uint64_t> knownVal =
101  getKnownLaunchDim(*this, LaunchDims::Block);
102  if (knownVal)
103  setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
104  else
105  setResultRange(getResult(), getIndexRange(1, kMaxDim));
106 }
107 
108 void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
109  SetIntRangeFn setResultRange) {
110  uint64_t max = getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
111  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
112 }
113 
114 void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
115  SetIntRangeFn setResultRange) {
116  std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid);
117  if (knownVal)
118  setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
119  else
120  setResultRange(getResult(), getIndexRange(1, kMaxDim));
121 }
122 
123 void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
124  SetIntRangeFn setResultRange) {
125  uint64_t max = getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
126  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
127 }
128 
129 void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
130  SetIntRangeFn setResultRange) {
131  setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1ULL));
132 }
133 
134 void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
135  SetIntRangeFn setResultRange) {
136  setResultRange(getResult(), getIndexRange(0, kMaxDim - 1ULL));
137 }
138 
139 void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
140  SetIntRangeFn setResultRange) {
141  uint64_t blockDimMax =
142  getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
143  uint64_t gridDimMax =
144  getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
145  setResultRange(getResult(),
146  getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL));
147 }
148 
149 void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
150  SetIntRangeFn setResultRange) {
151  setResultRange(getResult(), getIndexRange(1, kMaxDim));
152 }
153 
154 void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
155  SetIntRangeFn setResultRange) {
156  setResultRange(getResult(), getIndexRange(1, kMaxSubgroupSize));
157 }
158 
159 void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
160  SetIntRangeFn setResultRange) {
161  auto setRange = [&](const ConstantIntRanges &argRange, Value dimResult,
162  Value idxResult) {
163  if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
164  return;
165  ConstantIntRanges dimRange =
166  argRange.intersection(getIndexRange(1, kMaxDim));
167  setResultRange(dimResult, dimRange);
168  ConstantIntRanges idxRange =
169  getIndexRange(0, dimRange.umax().getZExtValue() - 1);
170  setResultRange(idxResult, idxRange);
171  };
172 
173  argRanges = argRanges.drop_front(getAsyncDependencies().size());
174  KernelDim3 gridDims = getGridSize();
175  KernelDim3 blockIds = getBlockIds();
176  setRange(argRanges[0], gridDims.x, blockIds.x);
177  setRange(argRanges[1], gridDims.y, blockIds.y);
178  setRange(argRanges[2], gridDims.z, blockIds.z);
179  KernelDim3 blockDims = getBlockSize();
180  KernelDim3 threadIds = getThreadIds();
181  setRange(argRanges[3], blockDims.x, threadIds.x);
182  setRange(argRanges[4], blockDims.y, threadIds.y);
183  setRange(argRanges[5], blockDims.z, threadIds.z);
184 }
static Value valueByDim(KernelDim3 dims, Dimension dim)
If the operation op is in a context that is annotated with maximum launch dimensions (a launch op wit...
static constexpr uint64_t kMaxClusterDim
static std::optional< uint64_t > getKnownLaunchDim(Op op, LaunchDims type)
static constexpr uint64_t kMaxDim
static uint64_t zext(uint32_t arg)
static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax)
static constexpr uint64_t kMaxSubgroupSize
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Block represents an ordered list of Operations.
Definition: Block.h:30
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create an ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
ConstantIntRanges intersection(const ConstantIntRanges &other) const
Returns the intersection (computed separately for signed and unsigned bounds) of this range and other...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
This provides public APIs that all operations should have.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
Utility class for the GPU dialect to represent triples of Values accessible through ....
Definition: GPUDialect.h:38