MLIR  19.0.0git
InferIntRangeInterfaceImpls.cpp
Go to the documentation of this file.
1 //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 
13 #include "llvm/Support/Debug.h"
14 #include <optional>
15 
16 #define DEBUG_TYPE "int-range-analysis"
17 
18 using namespace mlir;
19 using namespace mlir::arith;
20 using namespace mlir::intrange;
21 
22 //===----------------------------------------------------------------------===//
23 // ConstantOp
24 //===----------------------------------------------------------------------===//
25 
26 void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
27  SetIntRangeFn setResultRange) {
28  auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue());
29  if (constAttr) {
30  const APInt &value = constAttr.getValue();
31  setResultRange(getResult(), ConstantIntRanges::constant(value));
32  }
33 }
34 
35 //===----------------------------------------------------------------------===//
36 // AddIOp
37 //===----------------------------------------------------------------------===//
38 
39 void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
40  SetIntRangeFn setResultRange) {
41  setResultRange(getResult(), inferAdd(argRanges));
42 }
43 
44 //===----------------------------------------------------------------------===//
45 // SubIOp
46 //===----------------------------------------------------------------------===//
47 
48 void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
49  SetIntRangeFn setResultRange) {
50  setResultRange(getResult(), inferSub(argRanges));
51 }
52 
53 //===----------------------------------------------------------------------===//
54 // MulIOp
55 //===----------------------------------------------------------------------===//
56 
57 void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
58  SetIntRangeFn setResultRange) {
59  setResultRange(getResult(), inferMul(argRanges));
60 }
61 
62 //===----------------------------------------------------------------------===//
63 // DivUIOp
64 //===----------------------------------------------------------------------===//
65 
66 void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
67  SetIntRangeFn setResultRange) {
68  setResultRange(getResult(), inferDivU(argRanges));
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // DivSIOp
73 //===----------------------------------------------------------------------===//
74 
75 void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
76  SetIntRangeFn setResultRange) {
77  setResultRange(getResult(), inferDivS(argRanges));
78 }
79 
80 //===----------------------------------------------------------------------===//
81 // CeilDivUIOp
82 //===----------------------------------------------------------------------===//
83 
84 void arith::CeilDivUIOp::inferResultRanges(
85  ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
86  setResultRange(getResult(), inferCeilDivU(argRanges));
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // CeilDivSIOp
91 //===----------------------------------------------------------------------===//
92 
93 void arith::CeilDivSIOp::inferResultRanges(
94  ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
95  setResultRange(getResult(), inferCeilDivS(argRanges));
96 }
97 
98 //===----------------------------------------------------------------------===//
99 // FloorDivSIOp
100 //===----------------------------------------------------------------------===//
101 
102 void arith::FloorDivSIOp::inferResultRanges(
103  ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
104  return setResultRange(getResult(), inferFloorDivS(argRanges));
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // RemUIOp
109 //===----------------------------------------------------------------------===//
110 
111 void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
112  SetIntRangeFn setResultRange) {
113  setResultRange(getResult(), inferRemU(argRanges));
114 }
115 
116 //===----------------------------------------------------------------------===//
117 // RemSIOp
118 //===----------------------------------------------------------------------===//
119 
120 void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
121  SetIntRangeFn setResultRange) {
122  setResultRange(getResult(), inferRemS(argRanges));
123 }
124 
125 //===----------------------------------------------------------------------===//
126 // AndIOp
127 //===----------------------------------------------------------------------===//
128 
129 void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
130  SetIntRangeFn setResultRange) {
131  setResultRange(getResult(), inferAnd(argRanges));
132 }
133 
134 //===----------------------------------------------------------------------===//
135 // OrIOp
136 //===----------------------------------------------------------------------===//
137 
138 void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
139  SetIntRangeFn setResultRange) {
140  setResultRange(getResult(), inferOr(argRanges));
141 }
142 
143 //===----------------------------------------------------------------------===//
144 // XOrIOp
145 //===----------------------------------------------------------------------===//
146 
147 void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
148  SetIntRangeFn setResultRange) {
149  setResultRange(getResult(), inferXor(argRanges));
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // MaxSIOp
154 //===----------------------------------------------------------------------===//
155 
156 void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
157  SetIntRangeFn setResultRange) {
158  setResultRange(getResult(), inferMaxS(argRanges));
159 }
160 
161 //===----------------------------------------------------------------------===//
162 // MaxUIOp
163 //===----------------------------------------------------------------------===//
164 
165 void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
166  SetIntRangeFn setResultRange) {
167  setResultRange(getResult(), inferMaxU(argRanges));
168 }
169 
170 //===----------------------------------------------------------------------===//
171 // MinSIOp
172 //===----------------------------------------------------------------------===//
173 
174 void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
175  SetIntRangeFn setResultRange) {
176  setResultRange(getResult(), inferMinS(argRanges));
177 }
178 
179 //===----------------------------------------------------------------------===//
180 // MinUIOp
181 //===----------------------------------------------------------------------===//
182 
183 void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
184  SetIntRangeFn setResultRange) {
185  setResultRange(getResult(), inferMinU(argRanges));
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // ExtUIOp
190 //===----------------------------------------------------------------------===//
191 
192 void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
193  SetIntRangeFn setResultRange) {
194  unsigned destWidth =
195  ConstantIntRanges::getStorageBitwidth(getResult().getType());
196  setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
197 }
198 
199 //===----------------------------------------------------------------------===//
200 // ExtSIOp
201 //===----------------------------------------------------------------------===//
202 
203 void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
204  SetIntRangeFn setResultRange) {
205  unsigned destWidth =
206  ConstantIntRanges::getStorageBitwidth(getResult().getType());
207  setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
208 }
209 
210 //===----------------------------------------------------------------------===//
211 // TruncIOp
212 //===----------------------------------------------------------------------===//
213 
214 void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
215  SetIntRangeFn setResultRange) {
216  unsigned destWidth =
217  ConstantIntRanges::getStorageBitwidth(getResult().getType());
218  setResultRange(getResult(), truncRange(argRanges[0], destWidth));
219 }
220 
221 //===----------------------------------------------------------------------===//
222 // IndexCastOp
223 //===----------------------------------------------------------------------===//
224 
225 void arith::IndexCastOp::inferResultRanges(
226  ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
227  Type sourceType = getOperand().getType();
228  Type destType = getResult().getType();
229  unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
230  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
231 
232  if (srcWidth < destWidth)
233  setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
234  else if (srcWidth > destWidth)
235  setResultRange(getResult(), truncRange(argRanges[0], destWidth));
236  else
237  setResultRange(getResult(), argRanges[0]);
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // IndexCastUIOp
242 //===----------------------------------------------------------------------===//
243 
244 void arith::IndexCastUIOp::inferResultRanges(
245  ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
246  Type sourceType = getOperand().getType();
247  Type destType = getResult().getType();
248  unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
249  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
250 
251  if (srcWidth < destWidth)
252  setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
253  else if (srcWidth > destWidth)
254  setResultRange(getResult(), truncRange(argRanges[0], destWidth));
255  else
256  setResultRange(getResult(), argRanges[0]);
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // CmpIOp
261 //===----------------------------------------------------------------------===//
262 
263 void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
264  SetIntRangeFn setResultRange) {
265  arith::CmpIPredicate arithPred = getPredicate();
266  intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred);
267  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
268 
269  APInt min = APInt::getZero(1);
270  APInt max = APInt::getAllOnes(1);
271 
272  std::optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs);
273  if (truthValue.has_value() && *truthValue)
274  min = max;
275  else if (truthValue.has_value() && !(*truthValue))
276  max = min;
277 
278  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
279 }
280 
281 //===----------------------------------------------------------------------===//
282 // SelectOp
283 //===----------------------------------------------------------------------===//
284 
285 void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
286  SetIntRangeFn setResultRange) {
287  std::optional<APInt> mbCondVal = argRanges[0].getConstantValue();
288 
289  if (mbCondVal) {
290  if (mbCondVal->isZero())
291  setResultRange(getResult(), argRanges[2]);
292  else
293  setResultRange(getResult(), argRanges[1]);
294  return;
295  }
296  setResultRange(getResult(), argRanges[1].rangeUnion(argRanges[2]));
297 }
298 
299 //===----------------------------------------------------------------------===//
300 // ShLIOp
301 //===----------------------------------------------------------------------===//
302 
303 void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
304  SetIntRangeFn setResultRange) {
305  setResultRange(getResult(), inferShl(argRanges));
306 }
307 
308 //===----------------------------------------------------------------------===//
309 // ShRUIOp
310 //===----------------------------------------------------------------------===//
311 
312 void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
313  SetIntRangeFn setResultRange) {
314  setResultRange(getResult(), inferShrU(argRanges));
315 }
316 
317 //===----------------------------------------------------------------------===//
318 // ShRSIOp
319 //===----------------------------------------------------------------------===//
320 
321 void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
322  SetIntRangeFn setResultRange) {
323  setResultRange(getResult(), inferShrS(argRanges));
324 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create a ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
ConstantIntRanges inferAnd(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferSub(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShrS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extSIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the signed values in range to sign-extend it to destWidth.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
ConstantIntRanges inferMinS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMaxU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges)
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
ConstantIntRanges inferOr(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges truncRange(const ConstantIntRanges &range, unsigned destWidth)
Truncate range to destWidth bits, taking care to handle cases such as the truncation of [255,...
ConstantIntRanges inferDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMinU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferXor(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShrU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferFloorDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extUIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the unsigned values in range to zero-extend it to destWidth.
ConstantIntRanges inferShl(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMaxS(ArrayRef< ConstantIntRanges > argRanges)
Include the generated interface declarations.