MLIR  22.0.0git
InferIntRangeCommon.cpp
Go to the documentation of this file.
1 //===- InferIntRangeCommon.cpp - Inference for common ops ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains implementations of range inference for operations that are
10 // common to both the `arith` and `index` dialects to facilitate reuse.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
18 
19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/STLExtras.h"
21 
22 #include "llvm/Support/Debug.h"
23 
24 #include <iterator>
25 #include <optional>
26 
27 using namespace mlir;
28 
29 #define DEBUG_TYPE "int-range-analysis"
30 
31 //===----------------------------------------------------------------------===//
32 // General utilities
33 //===----------------------------------------------------------------------===//
34 
35 /// Function that evaluates the result of doing something on arithmetic
36 /// constants and returns std::nullopt on overflow.
37 using ConstArithFn =
38  function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
40  std::function<std::optional<APInt>(const APInt &, const APInt &)>;
41 
42 /// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
43 /// If either computation overflows, make the result unbounded.
44 static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
45  const APInt &minRight,
46  const APInt &maxLeft,
47  const APInt &maxRight, bool isSigned) {
48  std::optional<APInt> maybeMin = op(minLeft, minRight);
49  std::optional<APInt> maybeMax = op(maxLeft, maxRight);
50  if (maybeMin && maybeMax)
51  return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned);
52  return ConstantIntRanges::maxRange(minLeft.getBitWidth());
53 }
54 
55 /// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
56 /// ignoring unbounded values. Returns the maximal range if `op` overflows.
58  ArrayRef<APInt> rhs, bool isSigned) {
59  unsigned width = lhs[0].getBitWidth();
60  APInt min =
61  isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
62  APInt max =
63  isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
64  for (const APInt &left : lhs) {
65  for (const APInt &right : rhs) {
66  std::optional<APInt> maybeThisResult = op(left, right);
67  if (!maybeThisResult)
68  return ConstantIntRanges::maxRange(width);
69  APInt result = std::move(*maybeThisResult);
70  min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min;
71  max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max;
72  }
73  }
74  return ConstantIntRanges::range(min, max, isSigned);
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // Ext, trunc, index op handling
79 //===----------------------------------------------------------------------===//
80 
84  intrange::CmpMode mode) {
85  ConstantIntRanges sixtyFour = inferFn(argRanges);
87  llvm::transform(argRanges, std::back_inserter(truncated),
88  [](const ConstantIntRanges &range) {
89  return truncRange(range, /*destWidth=*/indexMinWidth);
90  });
91  ConstantIntRanges thirtyTwo = inferFn(truncated);
92  ConstantIntRanges thirtyTwoAsSixtyFour =
93  extRange(thirtyTwo, /*destWidth=*/indexMaxWidth);
94  ConstantIntRanges sixtyFourAsThirtyTwo =
95  truncRange(sixtyFour, /*destWidth=*/indexMinWidth);
96 
97  LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour
98  << " 32-bit = " << thirtyTwo << "\n");
99  bool truncEqual = false;
100  switch (mode) {
102  truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo);
103  break;
105  truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() &&
106  thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax());
107  break;
109  truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() &&
110  thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax());
111  break;
112  }
113  if (truncEqual)
114  // Returing the 64-bit result preserves more information.
115  return sixtyFour;
116  ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour);
117  return merged;
118 }
119 
121  unsigned int destWidth) {
122  APInt umin = range.umin().zext(destWidth);
123  APInt umax = range.umax().zext(destWidth);
124  APInt smin = range.smin().sext(destWidth);
125  APInt smax = range.smax().sext(destWidth);
126  return {umin, umax, smin, smax};
127 }
128 
130  unsigned destWidth) {
131  APInt umin = range.umin().zext(destWidth);
132  APInt umax = range.umax().zext(destWidth);
133  return ConstantIntRanges::fromUnsigned(umin, umax);
134 }
135 
137  unsigned destWidth) {
138  APInt smin = range.smin().sext(destWidth);
139  APInt smax = range.smax().sext(destWidth);
140  return ConstantIntRanges::fromSigned(smin, smax);
141 }
142 
144  unsigned int destWidth) {
145  // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb],
146  // the range of the resulting value is not contiguous ind includes 0.
147  // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2],
148  // but you can't truncate [255, 257] similarly.
149  bool hasUnsignedRollover =
150  range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
151  APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth)
152  : range.umin().trunc(destWidth);
153  APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
154  : range.umax().trunc(destWidth);
155 
156  // Signed post-truncation rollover will not occur when either:
157  // - The high parts of the min and max, plus the sign bit, are the same
158  // - The high halves + sign bit of the min and max are either all 1s or all 0s
159  // and you won't create a [positive, negative] range by truncating.
160  // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8
161  // but not [255, 257]_i16 to a range of i8s. You can also truncate
162  // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16.
163  // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e)
164  // will truncate to 0x7e, which is greater than 0
165  APInt sminHighPart = range.smin().ashr(destWidth - 1);
166  APInt smaxHighPart = range.smax().ashr(destWidth - 1);
167  bool hasSignedOverflow =
168  (sminHighPart != smaxHighPart) &&
169  !(sminHighPart.isAllOnes() &&
170  (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
171  !(sminHighPart.isZero() && smaxHighPart.isZero());
172  APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
173  : range.smin().trunc(destWidth);
174  APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
175  : range.smax().trunc(destWidth);
176  return {umin, umax, smin, smax};
177 }
178 
179 //===----------------------------------------------------------------------===//
180 // Addition
181 //===----------------------------------------------------------------------===//
182 
185  OverflowFlags ovfFlags) {
186  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
187 
188  ConstArithStdFn uadd = [=](const APInt &a,
189  const APInt &b) -> std::optional<APInt> {
190  bool overflowed = false;
191  APInt result = any(ovfFlags & OverflowFlags::Nuw)
192  ? a.uadd_sat(b)
193  : a.uadd_ov(b, overflowed);
194  return overflowed ? std::optional<APInt>() : result;
195  };
196  ConstArithStdFn sadd = [=](const APInt &a,
197  const APInt &b) -> std::optional<APInt> {
198  bool overflowed = false;
199  APInt result = any(ovfFlags & OverflowFlags::Nsw)
200  ? a.sadd_sat(b)
201  : a.sadd_ov(b, overflowed);
202  return overflowed ? std::optional<APInt>() : result;
203  };
204 
206  uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
208  sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
209  return urange.intersection(srange);
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // Subtraction
214 //===----------------------------------------------------------------------===//
215 
218  OverflowFlags ovfFlags) {
219  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
220 
221  ConstArithStdFn usub = [=](const APInt &a,
222  const APInt &b) -> std::optional<APInt> {
223  bool overflowed = false;
224  APInt result = any(ovfFlags & OverflowFlags::Nuw)
225  ? a.usub_sat(b)
226  : a.usub_ov(b, overflowed);
227  return overflowed ? std::optional<APInt>() : result;
228  };
229  ConstArithStdFn ssub = [=](const APInt &a,
230  const APInt &b) -> std::optional<APInt> {
231  bool overflowed = false;
232  APInt result = any(ovfFlags & OverflowFlags::Nsw)
233  ? a.ssub_sat(b)
234  : a.ssub_ov(b, overflowed);
235  return overflowed ? std::optional<APInt>() : result;
236  };
238  usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
240  ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
241  return urange.intersection(srange);
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // Multiplication
246 //===----------------------------------------------------------------------===//
247 
250  OverflowFlags ovfFlags) {
251  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
252 
253  ConstArithStdFn umul = [=](const APInt &a,
254  const APInt &b) -> std::optional<APInt> {
255  bool overflowed = false;
256  APInt result = any(ovfFlags & OverflowFlags::Nuw)
257  ? a.umul_sat(b)
258  : a.umul_ov(b, overflowed);
259  return overflowed ? std::optional<APInt>() : result;
260  };
261  ConstArithStdFn smul = [=](const APInt &a,
262  const APInt &b) -> std::optional<APInt> {
263  bool overflowed = false;
264  APInt result = any(ovfFlags & OverflowFlags::Nsw)
265  ? a.smul_sat(b)
266  : a.smul_ov(b, overflowed);
267  return overflowed ? std::optional<APInt>() : result;
268  };
269 
270  ConstantIntRanges urange =
271  minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
272  /*isSigned=*/false);
273  ConstantIntRanges srange =
274  minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
275  /*isSigned=*/true);
276  return urange.intersection(srange);
277 }
278 
279 //===----------------------------------------------------------------------===//
280 // DivU, CeilDivU (Unsigned division)
281 //===----------------------------------------------------------------------===//
282 
283 /// Fix up division results (ex. for ceiling and floor), returning an APInt
284 /// if there has been no overflow
286  const APInt &lhs, const APInt &rhs, const APInt &result)>;
287 
289  const ConstantIntRanges &rhs,
290  DivisionFixupFn fixup) {
291  const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
292  &rhsMax = rhs.umax();
293  if (!rhsMin.isZero() && !rhsMax.isZero()) {
294  auto udiv = [&fixup](const APInt &a,
295  const APInt &b) -> std::optional<APInt> {
296  return fixup(a, b, a.udiv(b));
297  };
298  return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
299  /*isSigned=*/false);
300  }
301 
302  APInt umin = APInt::getZero(rhsMin.getBitWidth());
303  if (lhsMin.uge(rhsMax) && !rhsMax.isZero())
304  umin = lhsMin.udiv(rhsMax);
305 
306  // X u/ Y u<= X.
307  APInt umax = lhsMax;
308  return ConstantIntRanges::fromUnsigned(umin, umax);
309 }
310 
313  return inferDivURange(argRanges[0], argRanges[1],
314  [](const APInt &lhs, const APInt &rhs,
315  const APInt &result) { return result; });
316 }
317 
320  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
321 
322  auto ceilDivUIFix = [](const APInt &lhs, const APInt &rhs,
323  const APInt &result) -> std::optional<APInt> {
324  if (!lhs.urem(rhs).isZero()) {
325  bool overflowed = false;
326  APInt corrected =
327  result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
328  return overflowed ? std::optional<APInt>() : corrected;
329  }
330  return result;
331  };
332  return inferDivURange(lhs, rhs, ceilDivUIFix);
333 }
334 
335 //===----------------------------------------------------------------------===//
336 // DivS, CeilDivS, FloorDivS (Signed division)
337 //===----------------------------------------------------------------------===//
338 
340  const ConstantIntRanges &rhs,
341  DivisionFixupFn fixup) {
342  const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
343  &rhsMax = rhs.smax();
344  bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
345 
346  if (canDivide) {
347  auto sdiv = [&fixup](const APInt &a,
348  const APInt &b) -> std::optional<APInt> {
349  bool overflowed = false;
350  APInt result = a.sdiv_ov(b, overflowed);
351  return overflowed ? std::optional<APInt>() : fixup(a, b, result);
352  };
353  return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
354  /*isSigned=*/true);
355  }
356  return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
357 }
358 
361  return inferDivSRange(argRanges[0], argRanges[1],
362  [](const APInt &lhs, const APInt &rhs,
363  const APInt &result) { return result; });
364 }
365 
368  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
369 
370  auto ceilDivSIFix = [](const APInt &lhs, const APInt &rhs,
371  const APInt &result) -> std::optional<APInt> {
372  if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
373  bool overflowed = false;
374  APInt corrected =
375  result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
376  return overflowed ? std::optional<APInt>() : corrected;
377  }
378  // Special case where the usual implementation of ceilDiv causes
379  // INT_MIN / [positive number] to be positive. This doesn't match the
380  // definition of signed ceiling division mathematically, but it prevents
381  // inconsistent constant-folding results. This arises because (-int_min) is
382  // still negative, so -(-int_min / b) is -(int_min / b), which is
383  // positive See #115293.
384  if (lhs.isMinSignedValue() && rhs.sgt(1)) {
385  return -result;
386  }
387  return result;
388  };
389  ConstantIntRanges result = inferDivSRange(lhs, rhs, ceilDivSIFix);
390  if (lhs.smin().isMinSignedValue() && lhs.smax().sgt(lhs.smin())) {
391  // If lhs range includes INT_MIN and lhs is not a single value, we can
392  // suddenly wrap to positive val, skipping entire negative range, add
393  // [INT_MIN + 1, smax()] range to the result to handle this.
394  auto newLhs = ConstantIntRanges::fromSigned(lhs.smin() + 1, lhs.smax());
395  result = result.rangeUnion(inferDivSRange(newLhs, rhs, ceilDivSIFix));
396  }
397  return result;
398 }
399 
402  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
403 
404  auto floorDivSIFix = [](const APInt &lhs, const APInt &rhs,
405  const APInt &result) -> std::optional<APInt> {
406  if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
407  bool overflowed = false;
408  APInt corrected =
409  result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
410  return overflowed ? std::optional<APInt>() : corrected;
411  }
412  return result;
413  };
414  return inferDivSRange(lhs, rhs, floorDivSIFix);
415 }
416 
417 //===----------------------------------------------------------------------===//
418 // Signed remainder (RemS)
419 //===----------------------------------------------------------------------===//
420 
423  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
424  const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
425  &rhsMax = rhs.smax();
426 
427  unsigned width = rhsMax.getBitWidth();
428  APInt smin = APInt::getSignedMinValue(width);
429  APInt smax = APInt::getSignedMaxValue(width);
430  // No bounds if zero could be a divisor.
431  bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
432  if (canBound) {
433  APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
434  bool canNegativeDividend = lhsMin.isNegative();
435  bool canPositiveDividend = lhsMax.isStrictlyPositive();
436  APInt zero = APInt::getZero(maxDivisor.getBitWidth());
437  APInt maxPositiveResult = maxDivisor - 1;
438  APInt minNegativeResult = -maxPositiveResult;
439  smin = canNegativeDividend ? minNegativeResult : zero;
440  smax = canPositiveDividend ? maxPositiveResult : zero;
441  // Special case: sweeping out a contiguous range in N/[modulus].
442  if (rhsMin == rhsMax) {
443  if ((lhsMax - lhsMin).ult(maxDivisor)) {
444  APInt minRem = lhsMin.srem(maxDivisor);
445  APInt maxRem = lhsMax.srem(maxDivisor);
446  if (minRem.sle(maxRem)) {
447  smin = minRem;
448  smax = maxRem;
449  }
450  }
451  }
452  }
453  return ConstantIntRanges::fromSigned(smin, smax);
454 }
455 
456 //===----------------------------------------------------------------------===//
457 // Unsigned remainder (RemU)
458 //===----------------------------------------------------------------------===//
459 
462  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
463  const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
464 
465  unsigned width = rhsMin.getBitWidth();
466  APInt umin = APInt::getZero(width);
467  // Remainder can't be larger than either of its arguments.
468  APInt umax = llvm::APIntOps::umin((rhsMax - 1), lhs.umax());
469 
470  if (!rhsMin.isZero()) {
471  // Special case: sweeping out a contiguous range in N/[modulus]
472  if (rhsMin == rhsMax) {
473  const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
474  if ((lhsMax - lhsMin).ult(rhsMax)) {
475  APInt minRem = lhsMin.urem(rhsMax);
476  APInt maxRem = lhsMax.urem(rhsMax);
477  if (minRem.ule(maxRem)) {
478  umin = minRem;
479  umax = maxRem;
480  }
481  }
482  }
483  }
484  return ConstantIntRanges::fromUnsigned(umin, umax);
485 }
486 
487 //===----------------------------------------------------------------------===//
488 // Max and min (MaxS, MaxU, MinS, MinU)
489 //===----------------------------------------------------------------------===//
490 
493  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
494 
495  const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
496  const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
497  return ConstantIntRanges::fromSigned(smin, smax);
498 }
499 
502  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
503 
504  const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
505  const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
506  return ConstantIntRanges::fromUnsigned(umin, umax);
507 }
508 
511  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
512 
513  const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
514  const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
515  return ConstantIntRanges::fromSigned(smin, smax);
516 }
517 
520  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
521 
522  const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
523  const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
524  return ConstantIntRanges::fromUnsigned(umin, umax);
525 }
526 
527 //===----------------------------------------------------------------------===//
528 // Bitwise operators (And, Or, Xor)
529 //===----------------------------------------------------------------------===//
530 
531 /// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
532 /// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
533 /// that both bounds have in common. This gives us a conservative approximation
534 /// for what values can be passed to bitwise operations.
535 static std::tuple<APInt, APInt>
537  APInt leftVal = bound.umin(), rightVal = bound.umax();
538  unsigned bitwidth = leftVal.getBitWidth();
539  unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
540  leftVal.clearLowBits(differingBits);
541  rightVal.setLowBits(differingBits);
542  return std::make_tuple(std::move(leftVal), std::move(rightVal));
543 }
544 
547  auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
548  auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
549  auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
550  return a & b;
551  };
552  return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
553  /*isSigned=*/false);
554 }
555 
558  auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
559  auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
560  auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
561  return a | b;
562  };
563  return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
564  /*isSigned=*/false);
565 }
566 
567 /// Get bitmask of all bits which can change while iterating in
568 /// [bound.umin(), bound.umax()].
569 static APInt getVaryingBitsMask(const ConstantIntRanges &bound) {
570  APInt leftVal = bound.umin(), rightVal = bound.umax();
571  unsigned bitwidth = leftVal.getBitWidth();
572  unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
573  return APInt::getLowBitsSet(bitwidth, differingBits);
574 }
575 
578  // Construct mask of varying bits for both ranges, xor values and then replace
579  // masked bits with 0s and 1s to get min and max values respectively.
580  ConstantIntRanges lhs = argRanges[0], rhs = argRanges[1];
581  APInt mask = getVaryingBitsMask(lhs) | getVaryingBitsMask(rhs);
582  APInt res = lhs.umin() ^ rhs.umin();
583  APInt min = res & ~mask;
584  APInt max = res | mask;
586 }
587 
588 //===----------------------------------------------------------------------===//
589 // Shifts (Shl, ShrS, ShrU)
590 //===----------------------------------------------------------------------===//
591 
594  OverflowFlags ovfFlags) {
595  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
596  const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax();
597 
598  // The signed/unsigned overflow behavior of shl by `rhs` matches a mul with
599  // 2^rhs.
600  ConstArithStdFn ushl = [=](const APInt &l,
601  const APInt &r) -> std::optional<APInt> {
602  bool overflowed = false;
603  APInt result = any(ovfFlags & OverflowFlags::Nuw)
604  ? l.ushl_sat(r)
605  : l.ushl_ov(r, overflowed);
606  return overflowed ? std::optional<APInt>() : result;
607  };
608  ConstArithStdFn sshl = [=](const APInt &l,
609  const APInt &r) -> std::optional<APInt> {
610  bool overflowed = false;
611  APInt result = any(ovfFlags & OverflowFlags::Nsw)
612  ? l.sshl_sat(r)
613  : l.sshl_ov(r, overflowed);
614  return overflowed ? std::optional<APInt>() : result;
615  };
616 
617  ConstantIntRanges urange =
618  minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhsUMin, rhsUMax},
619  /*isSigned=*/false);
620  ConstantIntRanges srange =
621  minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax},
622  /*isSigned=*/true);
623  return urange.intersection(srange);
624 }
625 
628  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
629 
630  auto ashr = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
631  return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
632  };
633 
634  return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
635  /*isSigned=*/true);
636 }
637 
640  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
641 
642  auto lshr = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
643  return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
644  };
645  return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
646  /*isSigned=*/false);
647 }
648 
649 //===----------------------------------------------------------------------===//
650 // Comparisons (Cmp)
651 //===----------------------------------------------------------------------===//
652 
654  switch (pred) {
675  }
676  llvm_unreachable("unknown cmp predicate value");
677 }
678 
680  const ConstantIntRanges &lhs,
681  const ConstantIntRanges &rhs) {
682  switch (pred) {
684  return lhs.smax().sle(rhs.smin());
686  return lhs.smax().slt(rhs.smin());
688  return lhs.umax().ule(rhs.umin());
690  return lhs.umax().ult(rhs.umin());
692  return lhs.smin().sge(rhs.smax());
694  return lhs.smin().sgt(rhs.smax());
696  return lhs.umin().uge(rhs.umax());
698  return lhs.umin().ugt(rhs.umax());
700  std::optional<APInt> lhsConst = lhs.getConstantValue();
701  std::optional<APInt> rhsConst = rhs.getConstantValue();
702  return lhsConst && rhsConst && lhsConst == rhsConst;
703  }
705  // While equality requires that there is an interpration of the preceeding
706  // computations that produces equal constants, whether that be signed or
707  // unsigned, statically determining inequality requires that neither
708  // interpretation produce potentially overlapping ranges.
709  bool sne = isStaticallyTrue(intrange::CmpPredicate::slt, lhs, rhs) ||
711  bool une = isStaticallyTrue(intrange::CmpPredicate::ult, lhs, rhs) ||
713  return sne && une;
714  }
715  }
716  return false;
717 }
718 
720  const ConstantIntRanges &lhs,
721  const ConstantIntRanges &rhs) {
722  if (isStaticallyTrue(pred, lhs, rhs))
723  return true;
724  if (isStaticallyTrue(invertPredicate(pred), lhs, rhs))
725  return false;
726  return std::nullopt;
727 }
728 
729 //===----------------------------------------------------------------------===//
730 // Shaped type dimension accessors / ShapedDimOpInterface
731 //===----------------------------------------------------------------------===//
732 
735  const IntegerValueRange &maybeDim) {
736  unsigned width =
737  ConstantIntRanges::getStorageBitwidth(op->getResult(0).getType());
738  APInt zero = APInt::getZero(width);
739  APInt typeMax = APInt::getSignedMaxValue(width);
740 
741  auto shapedTy = cast<ShapedType>(op.getShapedValue().getType());
742  if (!shapedTy.hasRank())
743  return ConstantIntRanges::fromSigned(zero, typeMax);
744 
745  int64_t rank = shapedTy.getRank();
746  int64_t minDim = 0;
747  int64_t maxDim = rank - 1;
748  if (!maybeDim.isUninitialized()) {
749  const ConstantIntRanges &dim = maybeDim.getValue();
750  minDim = std::max(minDim, dim.smin().getSExtValue());
751  maxDim = std::min(maxDim, dim.smax().getSExtValue());
752  }
753 
754  std::optional<ConstantIntRanges> result;
755  auto joinResult = [&](const ConstantIntRanges &thisResult) {
756  if (!result.has_value())
757  result = thisResult;
758  else
759  result = result->rangeUnion(thisResult);
760  };
761  for (int64_t i = minDim; i <= maxDim; ++i) {
762  int64_t length = shapedTy.getDimSize(i);
763 
764  if (ShapedType::isDynamic(length))
765  joinResult(ConstantIntRanges::fromSigned(zero, typeMax));
766  else
767  joinResult(ConstantIntRanges::constant(APInt(width, length)));
768  }
769  return result.value_or(ConstantIntRanges::fromSigned(zero, typeMax));
770 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs, const ConstantIntRanges &rhs, DivisionFixupFn fixup)
static bool isStaticallyTrue(intrange::CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred)
static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef< APInt > lhs, ArrayRef< APInt > rhs, bool isSigned)
Compute the minimum and maximum of (op(l, r) for l in lhs for r in rhs), ignoring unbounded values.
static std::tuple< APInt, APInt > widenBitwiseBounds(const ConstantIntRanges &bound)
"Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111,...
static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, const ConstantIntRanges &rhs, DivisionFixupFn fixup)
std::function< std::optional< APInt >(const APInt &, const APInt &)> ConstArithStdFn
static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, const APInt &minRight, const APInt &maxLeft, const APInt &maxRight, bool isSigned)
Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible. If either computation overflows,...
static APInt getVaryingBitsMask(const ConstantIntRanges &bound)
Get bitmask of all bits which can change while iterating in [bound.umin(), bound.umax()].
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges maxRange(unsigned bitwidth)
Create a ConstantIntRanges with the maximum bounds for the width bitwidth, that is - [0,...
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create an ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
static ConstantIntRanges range(const APInt &min, const APInt &max, bool isSigned)
Create a ConstantIntRanges whose minimum is min and maximum is max with isSigned specifying if the mi...
ConstantIntRanges intersection(const ConstantIntRanges &other) const
Returns the intersection (computed separately for signed and unsigned bounds) of this range and other...
static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax)
Create an ConstantIntRanges with the signed minimum and maximum equal to smin and smax,...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const
Returns the union (computed separately for signed and unsigned bounds) of this range and other.
This lattice value represents the integer range of an SSA value.
const ConstantIntRanges & getValue() const
Get the known integer value range.
bool isUninitialized() const
Whether the range is uninitialized.
ConstantIntRanges inferAnd(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShl(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferIndexOp(const InferRangeFn &inferFn, ArrayRef< ConstantIntRanges > argRanges, CmpMode mode)
Compute inferFn on ranges, whose size should be the index storage bitwidth.
ConstantIntRanges inferShrS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extSIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the signed values in range to sign-extend it to destWidth.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
ConstantIntRanges inferMinS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMaxU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemS(ArrayRef< ConstantIntRanges > argRanges)
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
ConstantIntRanges inferOr(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth)
Independently zero-extend the unsigned values and sign-extend the signed values in range to destWidth...
ConstantIntRanges inferSub(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges truncRange(const ConstantIntRanges &range, unsigned destWidth)
Truncate range to destWidth bits, taking care to handle cases such as the truncation of [255,...
ConstantIntRanges inferDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMinU(ArrayRef< ConstantIntRanges > argRanges)
std::function< ConstantIntRanges(ArrayRef< ConstantIntRanges >)> InferRangeFn
Function that performs inference on an array of ConstantIntRanges, abstracted away here to permit wri...
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferXor(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShrU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemU(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMinWidth
ConstantIntRanges inferFloorDivS(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMaxWidth
ConstantIntRanges inferCeilDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extUIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the unsigned values in range to zero-extend it to destWidth.
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
ConstantIntRanges inferMaxS(ArrayRef< ConstantIntRanges > argRanges)
Include the generated interface declarations.