12 #include "llvm/ADT/STLForwardCompat.h"
13 #include "llvm/Support/ErrorHandling.h"
14 #include "llvm/Support/MathExtras.h"
28 unsigned width = IndexType::kInternalStorageBitWidth;
/// Selects which family of launch dimensions a query is about: the
/// per-block sizes or the grid sizes of the enclosing launch context.
enum class LaunchDims : uint32_t {
  /// Block sizes (e.g. the block-size operands of a launch op).
  Block = 0,
  /// Grid sizes (e.g. the grid-size operands of a launch op).
  Grid = 1,
};
52 llvm_unreachable(
"All dimension enum cases handled above");
/// Zero-extends a 32-bit unsigned value to 64 bits.
/// Kept as a named free function so it can be passed directly as a mapping
/// callback (e.g. to llvm::transformOptional).
static uint64_t zext(uint32_t arg) {
  // Unsigned 32 -> 64 widening is value-preserving: upper bits become zero.
  const uint64_t widened = arg;
  return widened;
}
57 template <
typename Op>
59 Dimension dim = op.getDimension();
60 if (
auto launch = op->template getParentOfType<LaunchOp>()) {
63 case LaunchDims::Block:
64 bounds = launch.getBlockSizeOperandValues();
66 case LaunchDims::Grid:
67 bounds = launch.getGridSizeOperandValues();
73 return value.getZExtValue();
76 if (
auto func = op->template getParentOfType<GPUFuncOp>()) {
78 case LaunchDims::Block:
79 return llvm::transformOptional(func.getKnownBlockSize(dim),
zext);
80 case LaunchDims::Grid:
81 return llvm::transformOptional(func.getKnownGridSize(dim),
zext);
100 std::optional<uint64_t> knownVal =
103 setResultRange(getResult(),
getIndexRange(*knownVal, *knownVal));
118 setResultRange(getResult(),
getIndexRange(*knownVal, *knownVal));
141 uint64_t blockDimMax =
143 uint64_t gridDimMax =
145 setResultRange(getResult(),
163 if (argRange.
umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
167 setResultRange(dimResult, dimRange);
170 setResultRange(idxResult, idxRange);
173 argRanges = argRanges.drop_front(getAsyncDependencies().size());
176 setRange(argRanges[0], gridDims.
x, blockIds.
x);
177 setRange(argRanges[1], gridDims.
y, blockIds.
y);
178 setRange(argRanges[2], gridDims.
z, blockIds.
z);
181 setRange(argRanges[3], blockDims.
x, threadIds.
x);
182 setRange(argRanges[4], blockDims.
y, threadIds.
y);
183 setRange(argRanges[5], blockDims.
z, threadIds.
z);
static Value valueByDim(KernelDim3 dims, Dimension dim)
If the operation op is in a context that is annotated with maximum launch dimensions (a launch op wit...
static constexpr uint64_t kMaxClusterDim
static std::optional< uint64_t > getKnownLaunchDim(Op op, LaunchDims type)
static constexpr uint64_t kMaxDim
static uint64_t zext(uint32_t arg)
static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax)
static constexpr uint64_t kMaxSubgroupSize
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Block represents an ordered list of Operations.
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create a ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
ConstantIntRanges intersection(const ConstantIntRanges &other) const
Returns the intersection (computed separately for signed and unsigned bounds) of this range and other...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
This provides public APIs that all operations should have.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Utility class for the GPU dialect to represent triples of Values accessible through ....