MLIR  20.0.0git
InferIntRangeCommon.cpp
Go to the documentation of this file.
1 //===- InferIntRangeCommon.cpp - Inference for common ops ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains implementations of range inference for operations that are
10 // common to both the `arith` and `index` dialects to facilitate reuse.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
17 
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/STLExtras.h"
20 
21 #include "llvm/Support/Debug.h"
22 
23 #include <iterator>
24 #include <optional>
25 
26 using namespace mlir;
27 
28 #define DEBUG_TYPE "int-range-analysis"
29 
30 //===----------------------------------------------------------------------===//
31 // General utilities
32 //===----------------------------------------------------------------------===//
33 
34 /// Function that evaluates the result of doing something on arithmetic
35 /// constants and returns std::nullopt on overflow.
36 using ConstArithFn =
37  function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
39  std::function<std::optional<APInt>(const APInt &, const APInt &)>;
40 
41 /// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
42 /// If either computation overflows, make the result unbounded.
43 static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
44  const APInt &minRight,
45  const APInt &maxLeft,
46  const APInt &maxRight, bool isSigned) {
47  std::optional<APInt> maybeMin = op(minLeft, minRight);
48  std::optional<APInt> maybeMax = op(maxLeft, maxRight);
49  if (maybeMin && maybeMax)
50  return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned);
51  return ConstantIntRanges::maxRange(minLeft.getBitWidth());
52 }
53 
54 /// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
55 /// ignoring unbounded values. Returns the maximal range if `op` overflows.
57  ArrayRef<APInt> rhs, bool isSigned) {
58  unsigned width = lhs[0].getBitWidth();
59  APInt min =
60  isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
61  APInt max =
62  isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
63  for (const APInt &left : lhs) {
64  for (const APInt &right : rhs) {
65  std::optional<APInt> maybeThisResult = op(left, right);
66  if (!maybeThisResult)
67  return ConstantIntRanges::maxRange(width);
68  APInt result = std::move(*maybeThisResult);
69  min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min;
70  max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max;
71  }
72  }
73  return ConstantIntRanges::range(min, max, isSigned);
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // Ext, trunc, index op handling
78 //===----------------------------------------------------------------------===//
79 
83  intrange::CmpMode mode) {
84  ConstantIntRanges sixtyFour = inferFn(argRanges);
86  llvm::transform(argRanges, std::back_inserter(truncated),
87  [](const ConstantIntRanges &range) {
88  return truncRange(range, /*destWidth=*/indexMinWidth);
89  });
90  ConstantIntRanges thirtyTwo = inferFn(truncated);
91  ConstantIntRanges thirtyTwoAsSixtyFour =
92  extRange(thirtyTwo, /*destWidth=*/indexMaxWidth);
93  ConstantIntRanges sixtyFourAsThirtyTwo =
94  truncRange(sixtyFour, /*destWidth=*/indexMinWidth);
95 
96  LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour
97  << " 32-bit = " << thirtyTwo << "\n");
98  bool truncEqual = false;
99  switch (mode) {
101  truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo);
102  break;
104  truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() &&
105  thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax());
106  break;
108  truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() &&
109  thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax());
110  break;
111  }
112  if (truncEqual)
113  // Returning the 64-bit result preserves more information.
114  return sixtyFour;
115  ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour);
116  return merged;
117 }
118 
120  unsigned int destWidth) {
121  APInt umin = range.umin().zext(destWidth);
122  APInt umax = range.umax().zext(destWidth);
123  APInt smin = range.smin().sext(destWidth);
124  APInt smax = range.smax().sext(destWidth);
125  return {umin, umax, smin, smax};
126 }
127 
129  unsigned destWidth) {
130  APInt umin = range.umin().zext(destWidth);
131  APInt umax = range.umax().zext(destWidth);
132  return ConstantIntRanges::fromUnsigned(umin, umax);
133 }
134 
136  unsigned destWidth) {
137  APInt smin = range.smin().sext(destWidth);
138  APInt smax = range.smax().sext(destWidth);
139  return ConstantIntRanges::fromSigned(smin, smax);
140 }
141 
143  unsigned int destWidth) {
144  // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb],
145  // the range of the resulting value is not contiguous and includes 0.
146  // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2],
147  // but you can't truncate [255, 257] similarly.
148  bool hasUnsignedRollover =
149  range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
150  APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth)
151  : range.umin().trunc(destWidth);
152  APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
153  : range.umax().trunc(destWidth);
154 
155  // Signed post-truncation rollover will not occur when either:
156  // - The high parts of the min and max, plus the sign bit, are the same
157  // - The high halves + sign bit of the min and max are either all 1s or all 0s
158  // and you won't create a [positive, negative] range by truncating.
159  // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8
160  // but not [255, 257]_i16 to a range of i8s. You can also truncate
161  // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16.
162  // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e)
163  // will truncate to 0x7e, which is greater than 0
164  APInt sminHighPart = range.smin().ashr(destWidth - 1);
165  APInt smaxHighPart = range.smax().ashr(destWidth - 1);
166  bool hasSignedOverflow =
167  (sminHighPart != smaxHighPart) &&
168  !(sminHighPart.isAllOnes() &&
169  (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
170  !(sminHighPart.isZero() && smaxHighPart.isZero());
171  APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
172  : range.smin().trunc(destWidth);
173  APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
174  : range.smax().trunc(destWidth);
175  return {umin, umax, smin, smax};
176 }
177 
178 //===----------------------------------------------------------------------===//
179 // Addition
180 //===----------------------------------------------------------------------===//
181 
184  OverflowFlags ovfFlags) {
185  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
186 
187  ConstArithStdFn uadd = [=](const APInt &a,
188  const APInt &b) -> std::optional<APInt> {
189  bool overflowed = false;
190  APInt result = any(ovfFlags & OverflowFlags::Nuw)
191  ? a.uadd_sat(b)
192  : a.uadd_ov(b, overflowed);
193  return overflowed ? std::optional<APInt>() : result;
194  };
195  ConstArithStdFn sadd = [=](const APInt &a,
196  const APInt &b) -> std::optional<APInt> {
197  bool overflowed = false;
198  APInt result = any(ovfFlags & OverflowFlags::Nsw)
199  ? a.sadd_sat(b)
200  : a.sadd_ov(b, overflowed);
201  return overflowed ? std::optional<APInt>() : result;
202  };
203 
205  uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
207  sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
208  return urange.intersection(srange);
209 }
210 
211 //===----------------------------------------------------------------------===//
212 // Subtraction
213 //===----------------------------------------------------------------------===//
214 
217  OverflowFlags ovfFlags) {
218  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
219 
220  ConstArithStdFn usub = [=](const APInt &a,
221  const APInt &b) -> std::optional<APInt> {
222  bool overflowed = false;
223  APInt result = any(ovfFlags & OverflowFlags::Nuw)
224  ? a.usub_sat(b)
225  : a.usub_ov(b, overflowed);
226  return overflowed ? std::optional<APInt>() : result;
227  };
228  ConstArithStdFn ssub = [=](const APInt &a,
229  const APInt &b) -> std::optional<APInt> {
230  bool overflowed = false;
231  APInt result = any(ovfFlags & OverflowFlags::Nsw)
232  ? a.ssub_sat(b)
233  : a.ssub_ov(b, overflowed);
234  return overflowed ? std::optional<APInt>() : result;
235  };
237  usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
239  ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
240  return urange.intersection(srange);
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // Multiplication
245 //===----------------------------------------------------------------------===//
246 
249  OverflowFlags ovfFlags) {
250  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
251 
252  ConstArithStdFn umul = [=](const APInt &a,
253  const APInt &b) -> std::optional<APInt> {
254  bool overflowed = false;
255  APInt result = any(ovfFlags & OverflowFlags::Nuw)
256  ? a.umul_sat(b)
257  : a.umul_ov(b, overflowed);
258  return overflowed ? std::optional<APInt>() : result;
259  };
260  ConstArithStdFn smul = [=](const APInt &a,
261  const APInt &b) -> std::optional<APInt> {
262  bool overflowed = false;
263  APInt result = any(ovfFlags & OverflowFlags::Nsw)
264  ? a.smul_sat(b)
265  : a.smul_ov(b, overflowed);
266  return overflowed ? std::optional<APInt>() : result;
267  };
268 
269  ConstantIntRanges urange =
270  minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
271  /*isSigned=*/false);
272  ConstantIntRanges srange =
273  minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
274  /*isSigned=*/true);
275  return urange.intersection(srange);
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // DivU, CeilDivU (Unsigned division)
280 //===----------------------------------------------------------------------===//
281 
282 /// Fix up division results (ex. for ceiling and floor), returning an APInt
283 /// if there has been no overflow
285  const APInt &lhs, const APInt &rhs, const APInt &result)>;
286 
288  const ConstantIntRanges &rhs,
289  DivisionFixupFn fixup) {
290  const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
291  &rhsMax = rhs.umax();
292 
293  if (!rhsMin.isZero()) {
294  auto udiv = [&fixup](const APInt &a,
295  const APInt &b) -> std::optional<APInt> {
296  return fixup(a, b, a.udiv(b));
297  };
298  return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
299  /*isSigned=*/false);
300  }
301 
302  APInt umin = APInt::getZero(rhsMin.getBitWidth());
303  if (lhsMin.uge(rhsMax) && !rhsMax.isZero())
304  umin = lhsMin.udiv(rhsMax);
305 
306  // X u/ Y u<= X.
307  APInt umax = lhsMax;
308  return ConstantIntRanges::fromUnsigned(umin, umax);
309 }
310 
313  return inferDivURange(argRanges[0], argRanges[1],
314  [](const APInt &lhs, const APInt &rhs,
315  const APInt &result) { return result; });
316 }
317 
320  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
321 
322  auto ceilDivUIFix = [](const APInt &lhs, const APInt &rhs,
323  const APInt &result) -> std::optional<APInt> {
324  if (!lhs.urem(rhs).isZero()) {
325  bool overflowed = false;
326  APInt corrected =
327  result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
328  return overflowed ? std::optional<APInt>() : corrected;
329  }
330  return result;
331  };
332  return inferDivURange(lhs, rhs, ceilDivUIFix);
333 }
334 
335 //===----------------------------------------------------------------------===//
336 // DivS, CeilDivS, FloorDivS (Signed division)
337 //===----------------------------------------------------------------------===//
338 
340  const ConstantIntRanges &rhs,
341  DivisionFixupFn fixup) {
342  const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
343  &rhsMax = rhs.smax();
344  bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
345 
346  if (canDivide) {
347  auto sdiv = [&fixup](const APInt &a,
348  const APInt &b) -> std::optional<APInt> {
349  bool overflowed = false;
350  APInt result = a.sdiv_ov(b, overflowed);
351  return overflowed ? std::optional<APInt>() : fixup(a, b, result);
352  };
353  return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
354  /*isSigned=*/true);
355  }
356  return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
357 }
358 
361  return inferDivSRange(argRanges[0], argRanges[1],
362  [](const APInt &lhs, const APInt &rhs,
363  const APInt &result) { return result; });
364 }
365 
368  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
369 
370  auto ceilDivSIFix = [](const APInt &lhs, const APInt &rhs,
371  const APInt &result) -> std::optional<APInt> {
372  if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
373  bool overflowed = false;
374  APInt corrected =
375  result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
376  return overflowed ? std::optional<APInt>() : corrected;
377  }
378  // Special case where the usual implementation of ceilDiv causes
379  // INT_MIN / [positive number] to be positive. This doesn't match the
380  // definition of signed ceiling division mathematically, but it prevents
381  // inconsistent constant-folding results. This arises because (-int_min) is
382  // still negative, so -(-int_min / b) is -(int_min / b), which is
383  // positive See #115293.
384  if (lhs.isMinSignedValue() && rhs.sgt(1)) {
385  return -result;
386  }
387  return result;
388  };
389  return inferDivSRange(lhs, rhs, ceilDivSIFix);
390 }
391 
394  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
395 
396  auto floorDivSIFix = [](const APInt &lhs, const APInt &rhs,
397  const APInt &result) -> std::optional<APInt> {
398  if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
399  bool overflowed = false;
400  APInt corrected =
401  result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
402  return overflowed ? std::optional<APInt>() : corrected;
403  }
404  return result;
405  };
406  return inferDivSRange(lhs, rhs, floorDivSIFix);
407 }
408 
409 //===----------------------------------------------------------------------===//
410 // Signed remainder (RemS)
411 //===----------------------------------------------------------------------===//
412 
415  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
416  const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
417  &rhsMax = rhs.smax();
418 
419  unsigned width = rhsMax.getBitWidth();
420  APInt smin = APInt::getSignedMinValue(width);
421  APInt smax = APInt::getSignedMaxValue(width);
422  // No bounds if zero could be a divisor.
423  bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
424  if (canBound) {
425  APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
426  bool canNegativeDividend = lhsMin.isNegative();
427  bool canPositiveDividend = lhsMax.isStrictlyPositive();
428  APInt zero = APInt::getZero(maxDivisor.getBitWidth());
429  APInt maxPositiveResult = maxDivisor - 1;
430  APInt minNegativeResult = -maxPositiveResult;
431  smin = canNegativeDividend ? minNegativeResult : zero;
432  smax = canPositiveDividend ? maxPositiveResult : zero;
433  // Special case: sweeping out a contiguous range in N/[modulus].
434  if (rhsMin == rhsMax) {
435  if ((lhsMax - lhsMin).ult(maxDivisor)) {
436  APInt minRem = lhsMin.srem(maxDivisor);
437  APInt maxRem = lhsMax.srem(maxDivisor);
438  if (minRem.sle(maxRem)) {
439  smin = minRem;
440  smax = maxRem;
441  }
442  }
443  }
444  }
445  return ConstantIntRanges::fromSigned(smin, smax);
446 }
447 
448 //===----------------------------------------------------------------------===//
449 // Unsigned remainder (RemU)
450 //===----------------------------------------------------------------------===//
451 
454  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
455  const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
456 
457  unsigned width = rhsMin.getBitWidth();
458  APInt umin = APInt::getZero(width);
459  // Remainder can't be larger than either of its arguments.
460  APInt umax = llvm::APIntOps::umin((rhsMax - 1), lhs.umax());
461 
462  if (!rhsMin.isZero()) {
463  // Special case: sweeping out a contiguous range in N/[modulus]
464  if (rhsMin == rhsMax) {
465  const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
466  if ((lhsMax - lhsMin).ult(rhsMax)) {
467  APInt minRem = lhsMin.urem(rhsMax);
468  APInt maxRem = lhsMax.urem(rhsMax);
469  if (minRem.ule(maxRem)) {
470  umin = minRem;
471  umax = maxRem;
472  }
473  }
474  }
475  }
476  return ConstantIntRanges::fromUnsigned(umin, umax);
477 }
478 
479 //===----------------------------------------------------------------------===//
480 // Max and min (MaxS, MaxU, MinS, MinU)
481 //===----------------------------------------------------------------------===//
482 
485  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
486 
487  const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
488  const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
489  return ConstantIntRanges::fromSigned(smin, smax);
490 }
491 
494  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
495 
496  const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
497  const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
498  return ConstantIntRanges::fromUnsigned(umin, umax);
499 }
500 
503  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
504 
505  const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
506  const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
507  return ConstantIntRanges::fromSigned(smin, smax);
508 }
509 
512  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
513 
514  const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
515  const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
516  return ConstantIntRanges::fromUnsigned(umin, umax);
517 }
518 
519 //===----------------------------------------------------------------------===//
520 // Bitwise operators (And, Or, Xor)
521 //===----------------------------------------------------------------------===//
522 
523 /// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
524 /// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
525 /// that both bounds have in common. This gives us a conservative approximation
526 /// for what values can be passed to bitwise operations.
527 static std::tuple<APInt, APInt>
529  APInt leftVal = bound.umin(), rightVal = bound.umax();
530  unsigned bitwidth = leftVal.getBitWidth();
531  unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
532  leftVal.clearLowBits(differingBits);
533  rightVal.setLowBits(differingBits);
534  return std::make_tuple(std::move(leftVal), std::move(rightVal));
535 }
536 
539  auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
540  auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
541  auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
542  return a & b;
543  };
544  return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
545  /*isSigned=*/false);
546 }
547 
550  auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
551  auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
552  auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
553  return a | b;
554  };
555  return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
556  /*isSigned=*/false);
557 }
558 
559 /// Get bitmask of all bits which can change while iterating in
560 /// [bound.umin(), bound.umax()].
561 static APInt getVaryingBitsMask(const ConstantIntRanges &bound) {
562  APInt leftVal = bound.umin(), rightVal = bound.umax();
563  unsigned bitwidth = leftVal.getBitWidth();
564  unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
565  return APInt::getLowBitsSet(bitwidth, differingBits);
566 }
567 
570  // Construct mask of varying bits for both ranges, xor values and then replace
571  // masked bits with 0s and 1s to get min and max values respectively.
572  ConstantIntRanges lhs = argRanges[0], rhs = argRanges[1];
573  APInt mask = getVaryingBitsMask(lhs) | getVaryingBitsMask(rhs);
574  APInt res = lhs.umin() ^ rhs.umin();
575  APInt min = res & ~mask;
576  APInt max = res | mask;
578 }
579 
580 //===----------------------------------------------------------------------===//
581 // Shifts (Shl, ShrS, ShrU)
582 //===----------------------------------------------------------------------===//
583 
586  OverflowFlags ovfFlags) {
587  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
588  const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax();
589 
590  // The signed/unsigned overflow behavior of shl by `rhs` matches a mul with
591  // 2^rhs.
592  ConstArithStdFn ushl = [=](const APInt &l,
593  const APInt &r) -> std::optional<APInt> {
594  bool overflowed = false;
595  APInt result = any(ovfFlags & OverflowFlags::Nuw)
596  ? l.ushl_sat(r)
597  : l.ushl_ov(r, overflowed);
598  return overflowed ? std::optional<APInt>() : result;
599  };
600  ConstArithStdFn sshl = [=](const APInt &l,
601  const APInt &r) -> std::optional<APInt> {
602  bool overflowed = false;
603  APInt result = any(ovfFlags & OverflowFlags::Nsw)
604  ? l.sshl_sat(r)
605  : l.sshl_ov(r, overflowed);
606  return overflowed ? std::optional<APInt>() : result;
607  };
608 
609  ConstantIntRanges urange =
610  minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhsUMin, rhsUMax},
611  /*isSigned=*/false);
612  ConstantIntRanges srange =
613  minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax},
614  /*isSigned=*/true);
615  return urange.intersection(srange);
616 }
617 
620  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
621 
622  auto ashr = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
623  return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
624  };
625 
626  return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
627  /*isSigned=*/true);
628 }
629 
632  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
633 
634  auto lshr = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
635  return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
636  };
637  return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
638  /*isSigned=*/false);
639 }
640 
641 //===----------------------------------------------------------------------===//
642 // Comparisons (Cmp)
643 //===----------------------------------------------------------------------===//
644 
646  switch (pred) {
667  }
668  llvm_unreachable("unknown cmp predicate value");
669 }
670 
672  const ConstantIntRanges &lhs,
673  const ConstantIntRanges &rhs) {
674  switch (pred) {
676  return lhs.smax().sle(rhs.smin());
678  return lhs.smax().slt(rhs.smin());
680  return lhs.umax().ule(rhs.umin());
682  return lhs.umax().ult(rhs.umin());
684  return lhs.smin().sge(rhs.smax());
686  return lhs.smin().sgt(rhs.smax());
688  return lhs.umin().uge(rhs.umax());
690  return lhs.umin().ugt(rhs.umax());
692  std::optional<APInt> lhsConst = lhs.getConstantValue();
693  std::optional<APInt> rhsConst = rhs.getConstantValue();
694  return lhsConst && rhsConst && lhsConst == rhsConst;
695  }
697  // While equality requires that there is an interpretation of the preceding
698  // computations that produces equal constants, whether that be signed or
699  // unsigned, statically determining inequality requires that neither
700  // interpretation produce potentially overlapping ranges.
701  bool sne = isStaticallyTrue(intrange::CmpPredicate::slt, lhs, rhs) ||
703  bool une = isStaticallyTrue(intrange::CmpPredicate::ult, lhs, rhs) ||
705  return sne && une;
706  }
707  }
708  return false;
709 }
710 
712  const ConstantIntRanges &lhs,
713  const ConstantIntRanges &rhs) {
714  if (isStaticallyTrue(pred, lhs, rhs))
715  return true;
716  if (isStaticallyTrue(invertPredicate(pred), lhs, rhs))
717  return false;
718  return std::nullopt;
719 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs, const ConstantIntRanges &rhs, DivisionFixupFn fixup)
static bool isStaticallyTrue(intrange::CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred)
static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef< APInt > lhs, ArrayRef< APInt > rhs, bool isSigned)
Compute the minimum and maximum of (op(l, r) for l in lhs for r in rhs), ignoring unbounded values.
static std::tuple< APInt, APInt > widenBitwiseBounds(const ConstantIntRanges &bound)
"Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111,...
static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, const ConstantIntRanges &rhs, DivisionFixupFn fixup)
std::function< std::optional< APInt >(const APInt &, const APInt &)> ConstArithStdFn
static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, const APInt &minRight, const APInt &maxLeft, const APInt &maxRight, bool isSigned)
Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible, If either computation overflows,...
static APInt getVaryingBitsMask(const ConstantIntRanges &bound)
Get bitmask of all bits which can change while iterating in [bound.umin(), bound.umax()].
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges maxRange(unsigned bitwidth)
Create a ConstantIntRanges with the maximum bounds for the width bitwidth, that is - [0,...
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create an ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
static ConstantIntRanges range(const APInt &min, const APInt &max, bool isSigned)
Create a ConstantIntRanges whose minimum is min and maximum is max with isSigned specifying if the mi...
ConstantIntRanges intersection(const ConstantIntRanges &other) const
Returns the intersection (computed separately for signed and unsigned bounds) of this range and other...
static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax)
Create an ConstantIntRanges with the signed minimum and maximum equal to smin and smax,...
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const
Returns the union (computed separately for signed and unsigned bounds) of this range and other.
ConstantIntRanges inferAnd(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShl(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferIndexOp(const InferRangeFn &inferFn, ArrayRef< ConstantIntRanges > argRanges, CmpMode mode)
Compute inferFn on ranges, whose size should be the index storage bitwidth.
ConstantIntRanges inferShrS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extSIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the signed values in range to sign-extend it to destWidth.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
ConstantIntRanges inferMinS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMaxU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemS(ArrayRef< ConstantIntRanges > argRanges)
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
ConstantIntRanges inferOr(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth)
Independently zero-extend the unsigned values and sign-extend the signed values in range to destWidth...
ConstantIntRanges inferSub(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges truncRange(const ConstantIntRanges &range, unsigned destWidth)
Truncate range to destWidth bits, taking care to handle cases such as the truncation of [255,...
ConstantIntRanges inferDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMinU(ArrayRef< ConstantIntRanges > argRanges)
std::function< ConstantIntRanges(ArrayRef< ConstantIntRanges >)> InferRangeFn
Function that performs inference on an array of ConstantIntRanges, abstracted away here to permit wri...
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferXor(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShrU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemU(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMinWidth
ConstantIntRanges inferFloorDivS(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMaxWidth
ConstantIntRanges inferCeilDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extUIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the unsigned values in range to zero-extend it to destWidth.
ConstantIntRanges inferMaxS(ArrayRef< ConstantIntRanges > argRanges)
Include the generated interface declarations.