MLIR  20.0.0git
InferIntRangeInterfaceImpls.cpp
Go to the documentation of this file.
1 //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 
13 #include "llvm/Support/Debug.h"
14 #include <optional>
15 
16 #define DEBUG_TYPE "int-range-analysis"
17 
18 using namespace mlir;
19 using namespace mlir::arith;
20 using namespace mlir::intrange;
21 
23 convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
25  if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nsw))
26  retFlags |= intrange::OverflowFlags::Nsw;
27  if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nuw))
28  retFlags |= intrange::OverflowFlags::Nuw;
29  return retFlags;
30 }
31 
32 //===----------------------------------------------------------------------===//
33 // ConstantOp
34 //===----------------------------------------------------------------------===//
35 
36 void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
37  SetIntRangeFn setResultRange) {
38  if (auto scalarCstAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue())) {
39  const APInt &value = scalarCstAttr.getValue();
40  setResultRange(getResult(), ConstantIntRanges::constant(value));
41  return;
42  }
43  if (auto arrayCstAttr =
44  llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) {
45  if (arrayCstAttr.isSplat()) {
46  setResultRange(getResult(), ConstantIntRanges::constant(
47  arrayCstAttr.getSplatValue<APInt>()));
48  return;
49  }
50 
51  std::optional<ConstantIntRanges> result;
52  for (const APInt &val : arrayCstAttr) {
53  auto range = ConstantIntRanges::constant(val);
54  result = (result ? result->rangeUnion(range) : range);
55  }
56 
57  assert(result && "Zero-sized vectors are not allowed");
58  setResultRange(getResult(), *result);
59  return;
60  }
61 }
62 
63 //===----------------------------------------------------------------------===//
64 // AddIOp
65 //===----------------------------------------------------------------------===//
66 
67 void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
68  SetIntRangeFn setResultRange) {
69  setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
70  getOverflowFlags())));
71 }
72 
73 //===----------------------------------------------------------------------===//
74 // SubIOp
75 //===----------------------------------------------------------------------===//
76 
77 void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
78  SetIntRangeFn setResultRange) {
79  setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
80  getOverflowFlags())));
81 }
82 
83 //===----------------------------------------------------------------------===//
84 // MulIOp
85 //===----------------------------------------------------------------------===//
86 
87 void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
88  SetIntRangeFn setResultRange) {
89  setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
90  getOverflowFlags())));
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // DivUIOp
95 //===----------------------------------------------------------------------===//
96 
97 void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
98  SetIntRangeFn setResultRange) {
99  setResultRange(getResult(), inferDivU(argRanges));
100 }
101 
102 //===----------------------------------------------------------------------===//
103 // DivSIOp
104 //===----------------------------------------------------------------------===//
105 
106 void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
107  SetIntRangeFn setResultRange) {
108  setResultRange(getResult(), inferDivS(argRanges));
109 }
110 
111 //===----------------------------------------------------------------------===//
112 // CeilDivUIOp
113 //===----------------------------------------------------------------------===//
114 
115 void arith::CeilDivUIOp::inferResultRanges(
116  ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
117  setResultRange(getResult(), inferCeilDivU(argRanges));
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // CeilDivSIOp
122 //===----------------------------------------------------------------------===//
123 
124 void arith::CeilDivSIOp::inferResultRanges(
125  ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
126  setResultRange(getResult(), inferCeilDivS(argRanges));
127 }
128 
129 //===----------------------------------------------------------------------===//
130 // FloorDivSIOp
131 //===----------------------------------------------------------------------===//
132 
133 void arith::FloorDivSIOp::inferResultRanges(
134  ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
135  return setResultRange(getResult(), inferFloorDivS(argRanges));
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // RemUIOp
140 //===----------------------------------------------------------------------===//
141 
142 void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
143  SetIntRangeFn setResultRange) {
144  setResultRange(getResult(), inferRemU(argRanges));
145 }
146 
147 //===----------------------------------------------------------------------===//
148 // RemSIOp
149 //===----------------------------------------------------------------------===//
150 
151 void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
152  SetIntRangeFn setResultRange) {
153  setResultRange(getResult(), inferRemS(argRanges));
154 }
155 
156 //===----------------------------------------------------------------------===//
157 // AndIOp
158 //===----------------------------------------------------------------------===//
159 
160 void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
161  SetIntRangeFn setResultRange) {
162  setResultRange(getResult(), inferAnd(argRanges));
163 }
164 
165 //===----------------------------------------------------------------------===//
166 // OrIOp
167 //===----------------------------------------------------------------------===//
168 
169 void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
170  SetIntRangeFn setResultRange) {
171  setResultRange(getResult(), inferOr(argRanges));
172 }
173 
174 //===----------------------------------------------------------------------===//
175 // XOrIOp
176 //===----------------------------------------------------------------------===//
177 
178 void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
179  SetIntRangeFn setResultRange) {
180  setResultRange(getResult(), inferXor(argRanges));
181 }
182 
183 //===----------------------------------------------------------------------===//
184 // MaxSIOp
185 //===----------------------------------------------------------------------===//
186 
187 void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
188  SetIntRangeFn setResultRange) {
189  setResultRange(getResult(), inferMaxS(argRanges));
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // MaxUIOp
194 //===----------------------------------------------------------------------===//
195 
196 void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
197  SetIntRangeFn setResultRange) {
198  setResultRange(getResult(), inferMaxU(argRanges));
199 }
200 
201 //===----------------------------------------------------------------------===//
202 // MinSIOp
203 //===----------------------------------------------------------------------===//
204 
205 void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
206  SetIntRangeFn setResultRange) {
207  setResultRange(getResult(), inferMinS(argRanges));
208 }
209 
210 //===----------------------------------------------------------------------===//
211 // MinUIOp
212 //===----------------------------------------------------------------------===//
213 
214 void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
215  SetIntRangeFn setResultRange) {
216  setResultRange(getResult(), inferMinU(argRanges));
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // ExtUIOp
221 //===----------------------------------------------------------------------===//
222 
223 void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
224  SetIntRangeFn setResultRange) {
225  unsigned destWidth =
227  setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // ExtSIOp
232 //===----------------------------------------------------------------------===//
233 
234 void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
235  SetIntRangeFn setResultRange) {
236  unsigned destWidth =
238  setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // TruncIOp
243 //===----------------------------------------------------------------------===//
244 
245 void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
246  SetIntRangeFn setResultRange) {
247  unsigned destWidth =
249  setResultRange(getResult(), truncRange(argRanges[0], destWidth));
250 }
251 
252 //===----------------------------------------------------------------------===//
253 // IndexCastOp
254 //===----------------------------------------------------------------------===//
255 
256 void arith::IndexCastOp::inferResultRanges(
257  ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
258  Type sourceType = getOperand().getType();
259  Type destType = getResult().getType();
260  unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
261  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
262 
263  if (srcWidth < destWidth)
264  setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
265  else if (srcWidth > destWidth)
266  setResultRange(getResult(), truncRange(argRanges[0], destWidth));
267  else
268  setResultRange(getResult(), argRanges[0]);
269 }
270 
271 //===----------------------------------------------------------------------===//
272 // IndexCastUIOp
273 //===----------------------------------------------------------------------===//
274 
275 void arith::IndexCastUIOp::inferResultRanges(
276  ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
277  Type sourceType = getOperand().getType();
278  Type destType = getResult().getType();
279  unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
280  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
281 
282  if (srcWidth < destWidth)
283  setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
284  else if (srcWidth > destWidth)
285  setResultRange(getResult(), truncRange(argRanges[0], destWidth));
286  else
287  setResultRange(getResult(), argRanges[0]);
288 }
289 
290 //===----------------------------------------------------------------------===//
291 // CmpIOp
292 //===----------------------------------------------------------------------===//
293 
294 void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
295  SetIntRangeFn setResultRange) {
296  arith::CmpIPredicate arithPred = getPredicate();
297  intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred);
298  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
299 
300  APInt min = APInt::getZero(1);
301  APInt max = APInt::getAllOnes(1);
302 
303  std::optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs);
304  if (truthValue.has_value() && *truthValue)
305  min = max;
306  else if (truthValue.has_value() && !(*truthValue))
307  max = min;
308 
309  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
310 }
311 
312 //===----------------------------------------------------------------------===//
313 // SelectOp
314 //===----------------------------------------------------------------------===//
315 
316 void arith::SelectOp::inferResultRangesFromOptional(
317  ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
318  std::optional<APInt> mbCondVal =
319  argRanges[0].isUninitialized()
320  ? std::nullopt
321  : argRanges[0].getValue().getConstantValue();
322 
323  const IntegerValueRange &trueCase = argRanges[1];
324  const IntegerValueRange &falseCase = argRanges[2];
325 
326  if (mbCondVal) {
327  if (mbCondVal->isZero())
328  setResultRange(getResult(), falseCase);
329  else
330  setResultRange(getResult(), trueCase);
331  return;
332  }
333  setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase));
334 }
335 
336 //===----------------------------------------------------------------------===//
337 // ShLIOp
338 //===----------------------------------------------------------------------===//
339 
340 void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
341  SetIntRangeFn setResultRange) {
342  setResultRange(getResult(), inferShl(argRanges, convertArithOverflowFlags(
343  getOverflowFlags())));
344 }
345 
346 //===----------------------------------------------------------------------===//
347 // ShRUIOp
348 //===----------------------------------------------------------------------===//
349 
350 void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
351  SetIntRangeFn setResultRange) {
352  setResultRange(getResult(), inferShrU(argRanges));
353 }
354 
355 //===----------------------------------------------------------------------===//
356 // ShRSIOp
357 //===----------------------------------------------------------------------===//
358 
359 void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
360  SetIntRangeFn setResultRange) {
361  setResultRange(getResult(), inferShrS(argRanges));
362 }
static intrange::OverflowFlags convertArithOverflowFlags(arith::IntegerOverflowFlags flags)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create a ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
This lattice value represents the integer range of an SSA value.
static IntegerValueRange join(const IntegerValueRange &lhs, const IntegerValueRange &rhs)
Compute the least upper bound of two ranges.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
ConstantIntRanges inferAnd(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShl(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferShrS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extSIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the signed values in range to sign-extend it to destWidth.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
ConstantIntRanges inferMinS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMaxU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemS(ArrayRef< ConstantIntRanges > argRanges)
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
ConstantIntRanges inferOr(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferSub(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges truncRange(const ConstantIntRanges &range, unsigned destWidth)
Truncate range to destWidth bits, taking care to handle cases such as the truncation of [255,...
ConstantIntRanges inferDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMinU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferXor(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShrU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferFloorDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extUIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the unsigned values in range to zero-extend it to destWidth.
ConstantIntRanges inferMaxS(ArrayRef< ConstantIntRanges > argRanges)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305