MLIR  22.0.0git
InferIntRangeInterfaceImpls.cpp
Go to the documentation of this file.
1 //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 
13 #include <optional>
14 
15 #define DEBUG_TYPE "int-range-analysis"
16 
17 using namespace mlir;
18 using namespace mlir::arith;
19 using namespace mlir::intrange;
20 
22 convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
24  if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nsw))
25  retFlags |= intrange::OverflowFlags::Nsw;
26  if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nuw))
27  retFlags |= intrange::OverflowFlags::Nuw;
28  return retFlags;
29 }
30 
31 //===----------------------------------------------------------------------===//
32 // ConstantOp
33 //===----------------------------------------------------------------------===//
34 
35 void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
36  SetIntRangeFn setResultRange) {
37  if (auto scalarCstAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue())) {
38  const APInt &value = scalarCstAttr.getValue();
39  setResultRange(getResult(), ConstantIntRanges::constant(value));
40  return;
41  }
42  if (auto arrayCstAttr =
43  llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) {
44  if (arrayCstAttr.isSplat()) {
45  setResultRange(getResult(), ConstantIntRanges::constant(
46  arrayCstAttr.getSplatValue<APInt>()));
47  return;
48  }
49 
50  std::optional<ConstantIntRanges> result;
51  for (const APInt &val : arrayCstAttr) {
52  auto range = ConstantIntRanges::constant(val);
53  result = (result ? result->rangeUnion(range) : range);
54  }
55 
56  assert(result && "Zero-sized vectors are not allowed");
57  setResultRange(getResult(), *result);
58  return;
59  }
60 }
61 
62 //===----------------------------------------------------------------------===//
63 // AddIOp
64 //===----------------------------------------------------------------------===//
65 
66 void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
67  SetIntRangeFn setResultRange) {
68  setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
69  getOverflowFlags())));
70 }
71 
72 //===----------------------------------------------------------------------===//
73 // SubIOp
74 //===----------------------------------------------------------------------===//
75 
76 void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
77  SetIntRangeFn setResultRange) {
78  setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
79  getOverflowFlags())));
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // MulIOp
84 //===----------------------------------------------------------------------===//
85 
86 void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
87  SetIntRangeFn setResultRange) {
88  setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
89  getOverflowFlags())));
90 }
91 
92 //===----------------------------------------------------------------------===//
93 // DivUIOp
94 //===----------------------------------------------------------------------===//
95 
96 void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
97  SetIntRangeFn setResultRange) {
98  setResultRange(getResult(), inferDivU(argRanges));
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // DivSIOp
103 //===----------------------------------------------------------------------===//
104 
105 void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
106  SetIntRangeFn setResultRange) {
107  setResultRange(getResult(), inferDivS(argRanges));
108 }
109 
110 //===----------------------------------------------------------------------===//
111 // CeilDivUIOp
112 //===----------------------------------------------------------------------===//
113 
114 void arith::CeilDivUIOp::inferResultRanges(
115  ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
116  setResultRange(getResult(), inferCeilDivU(argRanges));
117 }
118 
119 //===----------------------------------------------------------------------===//
120 // CeilDivSIOp
121 //===----------------------------------------------------------------------===//
122 
123 void arith::CeilDivSIOp::inferResultRanges(
124  ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
125  setResultRange(getResult(), inferCeilDivS(argRanges));
126 }
127 
128 //===----------------------------------------------------------------------===//
129 // FloorDivSIOp
130 //===----------------------------------------------------------------------===//
131 
132 void arith::FloorDivSIOp::inferResultRanges(
133  ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
134  return setResultRange(getResult(), inferFloorDivS(argRanges));
135 }
136 
137 //===----------------------------------------------------------------------===//
138 // RemUIOp
139 //===----------------------------------------------------------------------===//
140 
141 void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
142  SetIntRangeFn setResultRange) {
143  setResultRange(getResult(), inferRemU(argRanges));
144 }
145 
146 //===----------------------------------------------------------------------===//
147 // RemSIOp
148 //===----------------------------------------------------------------------===//
149 
150 void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
151  SetIntRangeFn setResultRange) {
152  setResultRange(getResult(), inferRemS(argRanges));
153 }
154 
155 //===----------------------------------------------------------------------===//
156 // AndIOp
157 //===----------------------------------------------------------------------===//
158 
159 void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
160  SetIntRangeFn setResultRange) {
161  setResultRange(getResult(), inferAnd(argRanges));
162 }
163 
164 //===----------------------------------------------------------------------===//
165 // OrIOp
166 //===----------------------------------------------------------------------===//
167 
168 void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
169  SetIntRangeFn setResultRange) {
170  setResultRange(getResult(), inferOr(argRanges));
171 }
172 
173 //===----------------------------------------------------------------------===//
174 // XOrIOp
175 //===----------------------------------------------------------------------===//
176 
177 void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
178  SetIntRangeFn setResultRange) {
179  setResultRange(getResult(), inferXor(argRanges));
180 }
181 
182 //===----------------------------------------------------------------------===//
183 // MaxSIOp
184 //===----------------------------------------------------------------------===//
185 
186 void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
187  SetIntRangeFn setResultRange) {
188  setResultRange(getResult(), inferMaxS(argRanges));
189 }
190 
191 //===----------------------------------------------------------------------===//
192 // MaxUIOp
193 //===----------------------------------------------------------------------===//
194 
195 void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
196  SetIntRangeFn setResultRange) {
197  setResultRange(getResult(), inferMaxU(argRanges));
198 }
199 
200 //===----------------------------------------------------------------------===//
201 // MinSIOp
202 //===----------------------------------------------------------------------===//
203 
204 void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
205  SetIntRangeFn setResultRange) {
206  setResultRange(getResult(), inferMinS(argRanges));
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // MinUIOp
211 //===----------------------------------------------------------------------===//
212 
213 void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
214  SetIntRangeFn setResultRange) {
215  setResultRange(getResult(), inferMinU(argRanges));
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // ExtUIOp
220 //===----------------------------------------------------------------------===//
221 
222 void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
223  SetIntRangeFn setResultRange) {
224  unsigned destWidth =
226  setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
227 }
228 
229 //===----------------------------------------------------------------------===//
230 // ExtSIOp
231 //===----------------------------------------------------------------------===//
232 
233 void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
234  SetIntRangeFn setResultRange) {
235  unsigned destWidth =
237  setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // TruncIOp
242 //===----------------------------------------------------------------------===//
243 
244 void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
245  SetIntRangeFn setResultRange) {
246  unsigned destWidth =
248  setResultRange(getResult(), truncRange(argRanges[0], destWidth));
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // IndexCastOp
253 //===----------------------------------------------------------------------===//
254 
255 void arith::IndexCastOp::inferResultRanges(
256  ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
257  Type sourceType = getOperand().getType();
258  Type destType = getResult().getType();
259  unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
260  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
261 
262  if (srcWidth < destWidth)
263  setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
264  else if (srcWidth > destWidth)
265  setResultRange(getResult(), truncRange(argRanges[0], destWidth));
266  else
267  setResultRange(getResult(), argRanges[0]);
268 }
269 
270 //===----------------------------------------------------------------------===//
271 // IndexCastUIOp
272 //===----------------------------------------------------------------------===//
273 
274 void arith::IndexCastUIOp::inferResultRanges(
275  ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
276  Type sourceType = getOperand().getType();
277  Type destType = getResult().getType();
278  unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
279  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
280 
281  if (srcWidth < destWidth)
282  setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
283  else if (srcWidth > destWidth)
284  setResultRange(getResult(), truncRange(argRanges[0], destWidth));
285  else
286  setResultRange(getResult(), argRanges[0]);
287 }
288 
289 //===----------------------------------------------------------------------===//
290 // CmpIOp
291 //===----------------------------------------------------------------------===//
292 
293 void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
294  SetIntRangeFn setResultRange) {
295  arith::CmpIPredicate arithPred = getPredicate();
296  intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred);
297  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
298 
299  APInt min = APInt::getZero(1);
300  APInt max = APInt::getAllOnes(1);
301 
302  std::optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs);
303  if (truthValue.has_value() && *truthValue)
304  min = max;
305  else if (truthValue.has_value() && !(*truthValue))
306  max = min;
307 
308  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
309 }
310 
311 //===----------------------------------------------------------------------===//
312 // SelectOp
313 //===----------------------------------------------------------------------===//
314 
315 void arith::SelectOp::inferResultRangesFromOptional(
316  ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
317  std::optional<APInt> mbCondVal =
318  argRanges[0].isUninitialized()
319  ? std::nullopt
320  : argRanges[0].getValue().getConstantValue();
321 
322  const IntegerValueRange &trueCase = argRanges[1];
323  const IntegerValueRange &falseCase = argRanges[2];
324 
325  if (mbCondVal) {
326  if (mbCondVal->isZero())
327  setResultRange(getResult(), falseCase);
328  else
329  setResultRange(getResult(), trueCase);
330  return;
331  }
332  setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase));
333 }
334 
335 //===----------------------------------------------------------------------===//
336 // ShLIOp
337 //===----------------------------------------------------------------------===//
338 
339 void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
340  SetIntRangeFn setResultRange) {
341  setResultRange(getResult(), inferShl(argRanges, convertArithOverflowFlags(
342  getOverflowFlags())));
343 }
344 
345 //===----------------------------------------------------------------------===//
346 // ShRUIOp
347 //===----------------------------------------------------------------------===//
348 
349 void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
350  SetIntRangeFn setResultRange) {
351  setResultRange(getResult(), inferShrU(argRanges));
352 }
353 
354 //===----------------------------------------------------------------------===//
355 // ShRSIOp
356 //===----------------------------------------------------------------------===//
357 
358 void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
359  SetIntRangeFn setResultRange) {
360  setResultRange(getResult(), inferShrS(argRanges));
361 }
static intrange::OverflowFlags convertArithOverflowFlags(arith::IntegerOverflowFlags flags)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create a ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
This lattice value represents the integer range of an SSA value.
static IntegerValueRange join(const IntegerValueRange &lhs, const IntegerValueRange &rhs)
Compute the least upper bound of two ranges.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
ConstantIntRanges inferAnd(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShl(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferShrS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extSIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the signed values in range to sign-extend it to destWidth.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
ConstantIntRanges inferMinS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMaxU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemS(ArrayRef< ConstantIntRanges > argRanges)
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
ConstantIntRanges inferOr(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferSub(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges truncRange(const ConstantIntRanges &range, unsigned destWidth)
Truncate range to destWidth bits, taking care to handle cases such as the truncation of [255,...
ConstantIntRanges inferDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMinU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferXor(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShrU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferFloorDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extUIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the unsigned values in range to zero-extend it to destWidth.
ConstantIntRanges inferMaxS(ArrayRef< ConstantIntRanges > argRanges)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304