MLIR  16.0.0git
InferIntRangeInterfaceImpls.cpp
Go to the documentation of this file.
1 //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for gpu -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 
12 using namespace mlir;
13 using namespace mlir::gpu;
14 
// No known GPU has a grid or block dimension of 2^32 or larger, so every such
// dimension fits in an unsigned 32-bit value.
static constexpr uint64_t kMaxDim =
    static_cast<uint64_t>(std::numeric_limits<uint32_t>::max());
// No known GPU has a subgroup size larger than 128 (= 2^7).
static constexpr uint64_t kMaxSubgroupSize = uint64_t{1} << 7;
19 
20 static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) {
21  unsigned width = IndexType::kInternalStorageBitWidth;
22  return ConstantIntRanges::fromUnsigned(APInt(width, umin),
23  APInt(width, umax));
24 }
25 
26 void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
27  SetIntRangeFn setResultRange) {
28  setResultRange(getResult(), getIndexRange(1, kMaxDim));
29 }
30 
31 void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
32  SetIntRangeFn setResultRange) {
33  setResultRange(getResult(), getIndexRange(0, kMaxDim - 1));
34 }
35 
36 void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
37  SetIntRangeFn setResultRange) {
38  setResultRange(getResult(), getIndexRange(1, kMaxDim));
39 }
40 
41 void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
42  SetIntRangeFn setResultRange) {
43  setResultRange(getResult(), getIndexRange(0, kMaxDim - 1));
44 }
45 
46 void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
47  SetIntRangeFn setResultRange) {
48  setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1));
49 }
50 
51 void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
52  SetIntRangeFn setResultRange) {
53  setResultRange(getResult(), getIndexRange(0, kMaxDim - 1));
54 }
55 
56 void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
57  SetIntRangeFn setResultRange) {
58  setResultRange(getResult(),
60 }
61 
62 void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
63  SetIntRangeFn setResultRange) {
64  setResultRange(getResult(), getIndexRange(1, kMaxDim));
65 }
66 
67 void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
68  SetIntRangeFn setResultRange) {
69  setResultRange(getResult(), getIndexRange(1, kMaxSubgroupSize));
70 }
71 
72 void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
73  SetIntRangeFn setResultRange) {
74  auto setRange = [&](ConstantIntRanges argRange, Value dimResult,
75  Value idxResult) {
76  if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
77  return;
78  ConstantIntRanges dimRange =
79  argRange.intersection(getIndexRange(1, kMaxDim));
80  setResultRange(dimResult, dimRange);
81  ConstantIntRanges idxRange =
82  getIndexRange(0, dimRange.umax().getZExtValue() - 1);
83  setResultRange(idxResult, idxRange);
84  };
85 
86  argRanges = argRanges.drop_front(asyncDependencies().size());
87  KernelDim3 gridDims = getGridSize();
88  KernelDim3 blockIds = getBlockIds();
89  setRange(argRanges[0], gridDims.x, blockIds.x);
90  setRange(argRanges[1], gridDims.y, blockIds.y);
91  setRange(argRanges[2], gridDims.z, blockIds.z);
92  KernelDim3 blockDims = getBlockSize();
93  KernelDim3 threadIds = getThreadIds();
94  setRange(argRanges[3], blockDims.x, threadIds.x);
95  setRange(argRanges[4], blockDims.y, threadIds.y);
96  setRange(argRanges[5], blockDims.z, threadIds.z);
97 }
Include the generated interface declarations.
static constexpr uint64_t kMaxSubgroupSize
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create a ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax)
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
Utility class for the GPU dialect to represent triples of Values accessible through ...
Definition: GPUDialect.h:34
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
static constexpr uint64_t kMaxDim
ConstantIntRanges intersection(const ConstantIntRanges &other) const
Returns the intersection (computed separately for signed and unsigned bounds) of this range and other...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
A set of arbitrary-precision integers representing bounds on a given integer value.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)