// MLIR — ValueBoundsOpInterfaceImpl.cpp (GPU dialect), extracted from the
// Doxygen-rendered source page.
//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// NOTE(review): the rendered page dropped the include lines; reconstructed
// from the symbols used below — verify paths against the MLIR tree.
#include "mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h"

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using namespace mlir::gpu;

18namespace {
19/// Implement ValueBoundsOpInterface (which only works on index-typed values,
20/// gathers a set of constraint expressions, and is used for affine analyses)
21/// in terms of InferIntRangeInterface (which works
22/// on arbitrary integer types, creates [min, max] ranges, and is used in for
23/// arithmetic simplification).
24template <typename Op>
25struct GpuIdOpInterface
26 : public ValueBoundsOpInterface::ExternalModel<GpuIdOpInterface<Op>, Op> {
27 void populateBoundsForIndexValue(Operation *op, Value value,
28 ValueBoundsConstraintSet &cstr) const {
29 auto inferrable = cast<InferIntRangeInterface>(op);
30 assert(value == op->getResult(0) &&
31 "inferring for value that isn't the GPU op's result");
32 auto translateConstraint = [&](Value v, const ConstantIntRanges &range) {
33 assert(v == value &&
34 "GPU ID op inferring values for something that's not its result");
35 cstr.bound(v) >= range.smin().getSExtValue();
36 cstr.bound(v) <= range.smax().getSExtValue();
37 };
38 assert(inferrable->getNumOperands() == 0 && "ID ops have no operands");
39 inferrable.inferResultRanges({}, translateConstraint);
40 }
41};
42
43/// Implement ValueBoundsOpInterface on subgroup broadcast operations to
44/// indicate that such a broadcast does not modify the ranges of the values in
45/// question. Handles shaped types just in case one wants to broadcast a memref
46/// descriptor.
47struct SubgroupBroadcastOpInterface
48 : public ValueBoundsOpInterface::ExternalModel<SubgroupBroadcastOpInterface,
49 SubgroupBroadcastOp> {
50 void populateBoundsForIndexValue(Operation *op, Value value,
51 ValueBoundsConstraintSet &cstr) const {
52 auto broadcastOp = cast<SubgroupBroadcastOp>(op);
53 assert(value == broadcastOp.getResult() && "invalid value");
54 cstr.bound(value) == cstr.getExpr(broadcastOp.getSrc());
55 }
56
57 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
58 ValueBoundsConstraintSet &cstr) const {
59 auto broadcastOp = cast<SubgroupBroadcastOp>(op);
60 assert(value == broadcastOp.getResult() && "invalid value");
61 cstr.bound(value)[dim] == cstr.getExpr(broadcastOp.getSrc(), dim);
62 }
63};
64
65struct GpuLaunchOpInterface
66 : public ValueBoundsOpInterface::ExternalModel<GpuLaunchOpInterface,
67 LaunchOp> {
68 void populateBoundsForIndexValue(Operation *op, Value value,
69 ValueBoundsConstraintSet &cstr) const {
70 auto launchOp = cast<LaunchOp>(op);
71
72 Value sizeArg = nullptr;
73 bool isSize = false;
74 KernelDim3 gridSizeArgs = launchOp.getGridSizeOperandValues();
75 KernelDim3 blockSizeArgs = launchOp.getBlockSizeOperandValues();
76
77 auto match = [&](KernelDim3 bodyArgs, KernelDim3 externalArgs,
78 bool areSizeArgs) {
79 if (value == bodyArgs.x) {
80 sizeArg = externalArgs.x;
81 isSize = areSizeArgs;
82 }
83 if (value == bodyArgs.y) {
84 sizeArg = externalArgs.y;
85 isSize = areSizeArgs;
86 }
87 if (value == bodyArgs.z) {
88 sizeArg = externalArgs.z;
89 isSize = areSizeArgs;
90 }
91 };
92 match(launchOp.getThreadIds(), blockSizeArgs, false);
93 match(launchOp.getBlockSize(), blockSizeArgs, true);
94 match(launchOp.getBlockIds(), gridSizeArgs, false);
95 match(launchOp.getGridSize(), gridSizeArgs, true);
96 if (launchOp.hasClusterSize()) {
97 KernelDim3 clusterSizeArgs = *launchOp.getClusterSizeOperandValues();
98 match(*launchOp.getClusterIds(), clusterSizeArgs, false);
99 match(*launchOp.getClusterSize(), clusterSizeArgs, true);
100 }
101
102 if (!sizeArg)
103 return;
104 if (isSize) {
105 cstr.bound(value) == cstr.getExpr(sizeArg);
106 cstr.bound(value) >= 1;
107 } else {
108 cstr.bound(value) < cstr.getExpr(sizeArg);
109 cstr.bound(value) >= 0;
110 }
111 }
112};
113} // namespace
116 DialectRegistry &registry) {
117 registry.addExtension(+[](MLIRContext *ctx, GPUDialect *dialect) {
118#define REGISTER(X) X::attachInterface<GpuIdOpInterface<X>>(*ctx);
119 REGISTER(ClusterDimOp)
120 REGISTER(ClusterDimBlocksOp)
121 REGISTER(ClusterIdOp)
122 REGISTER(ClusterBlockIdOp)
123 REGISTER(BlockDimOp)
124 REGISTER(BlockIdOp)
125 REGISTER(GridDimOp)
126 REGISTER(ThreadIdOp)
127 REGISTER(LaneIdOp)
128 REGISTER(SubgroupIdOp)
129 REGISTER(GlobalIdOp)
130 REGISTER(NumSubgroupsOp)
131 REGISTER(SubgroupSizeOp)
132#undef REGISTER
133
134 LaunchOp::attachInterface<GpuLaunchOpInterface>(*ctx);
135 SubgroupBroadcastOp::attachInterface<SubgroupBroadcastOpInterface>(*ctx);
136 });
137}
// Doxygen cross-reference notes from the rendered page, preserved as
// comments so the file remains valid C++:
// - DialectRegistry: maps a dialect namespace to a constructor for the
//   matching dialect.
// - DialectRegistry::addExtension(TypeID extensionID,
//   std::unique_ptr<DialectExtensionBase> extension): add the given
//   extension to the registry.
// - MLIRContext: the top-level object for a collection of MLIR operations
//   (MLIRContext.h:63).
// - Operation::getResult(unsigned idx): get the 'idx'th result of this
//   operation (Operation.h:415).
// - ValueBoundsConstraintSet::getExpr(Value value,
//   std::optional<int64_t> dim = std::nullopt): return an expression that
//   represents the given index-typed value or shaped value dimension.
// - ValueBoundsConstraintSet::bound(Value value): add a bound for the given
//   index-typed value or shaped value.
// - registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry).