MLIR  21.0.0git
InferIntRangeCommon.cpp
Go to the documentation of this file.
1 //===- InferIntRangeCommon.cpp - Inference for common ops ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains implementations of range inference for operations that are
10 // common to both the `arith` and `index` dialects to facilitate reuse.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
18 
19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/STLExtras.h"
21 
22 #include "llvm/Support/Debug.h"
23 
24 #include <iterator>
25 #include <optional>
26 
27 using namespace mlir;
28 
29 #define DEBUG_TYPE "int-range-analysis"
30 
31 //===----------------------------------------------------------------------===//
32 // General utilities
33 //===----------------------------------------------------------------------===//
34 
35 /// Function that evaluates the result of doing something on arithmetic
36 /// constants and returns std::nullopt on overflow.
37 using ConstArithFn =
38  function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
40  std::function<std::optional<APInt>(const APInt &, const APInt &)>;
41 
42 /// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
43 /// If either computation overflows, make the result unbounded.
44 static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
45  const APInt &minRight,
46  const APInt &maxLeft,
47  const APInt &maxRight, bool isSigned) {
48  std::optional<APInt> maybeMin = op(minLeft, minRight);
49  std::optional<APInt> maybeMax = op(maxLeft, maxRight);
50  if (maybeMin && maybeMax)
51  return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned);
52  return ConstantIntRanges::maxRange(minLeft.getBitWidth());
53 }
54 
55 /// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
56 /// ignoring unbounded values. Returns the maximal range if `op` overflows.
58  ArrayRef<APInt> rhs, bool isSigned) {
59  unsigned width = lhs[0].getBitWidth();
60  APInt min =
61  isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
62  APInt max =
63  isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
64  for (const APInt &left : lhs) {
65  for (const APInt &right : rhs) {
66  std::optional<APInt> maybeThisResult = op(left, right);
67  if (!maybeThisResult)
68  return ConstantIntRanges::maxRange(width);
69  APInt result = std::move(*maybeThisResult);
70  min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min;
71  max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max;
72  }
73  }
74  return ConstantIntRanges::range(min, max, isSigned);
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // Ext, trunc, index op handling
79 //===----------------------------------------------------------------------===//
80 
/// `inferIndexOp`: evaluates `inferFn` twice. The first run uses `argRanges`
/// as given; the second uses those ranges truncated to `indexMinWidth` bits.
/// The narrow result is compared with the wide result truncated to
/// `indexMinWidth`, using the comparison selected by `mode`. If they agree,
/// the wide result is returned. Otherwise the wide result is joined with the
/// narrow result extended to `indexMaxWidth`.
/// NOTE(review): the signature and source lines 86, 101, 104 and 108 are
/// absent from this listing. `truncated` is presumably a container of
/// ConstantIntRanges declared on line 86. The three comparisons presumably
/// belong to the both/signed/unsigned CmpMode cases. Confirm against the
/// original file.
84  intrange::CmpMode mode) {
85  ConstantIntRanges sixtyFour = inferFn(argRanges);
  // Re-run the inference as if every operand were only indexMinWidth bits.
87  llvm::transform(argRanges, std::back_inserter(truncated),
88  [](const ConstantIntRanges &range) {
89  return truncRange(range, /*destWidth=*/indexMinWidth);
90  });
91  ConstantIntRanges thirtyTwo = inferFn(truncated);
92  ConstantIntRanges thirtyTwoAsSixtyFour =
93  extRange(thirtyTwo, /*destWidth=*/indexMaxWidth);
94  ConstantIntRanges sixtyFourAsThirtyTwo =
95  truncRange(sixtyFour, /*destWidth=*/indexMinWidth);
96
97  LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour
98  << " 32-bit = " << thirtyTwo << "\n");
99  bool truncEqual = false;
  // Decide whether the narrow and wide computations agree, comparing both
  // views, only the signed bounds, or only the unsigned bounds.
100  switch (mode) {
102  truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo);
103  break;
105  truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() &&
106  thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax());
107  break;
109  truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() &&
110  thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax());
111  break;
112  }
113  if (truncEqual)
114  // Returning the 64-bit result preserves more information.
115  return sixtyFour;
  // The computations disagree: cover both possibilities.
116  ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour);
117  return merged;
118 }
119 
/// `extRange`: zero-extends the unsigned bounds and sign-extends the signed
/// bounds of `range` to `destWidth`, keeping the two views independent.
/// NOTE(review): APInt::zext/sext require destWidth >= the current bit width.
121  unsigned int destWidth) {
122  APInt umin = range.umin().zext(destWidth);
123  APInt umax = range.umax().zext(destWidth);
124  APInt smin = range.smin().sext(destWidth);
125  APInt smax = range.smax().sext(destWidth);
126  return {umin, umax, smin, smax};
127 }
128 
/// `extUIRange`: zero-extends only the unsigned bounds of `range` to
/// `destWidth`. The signed bounds of the result are left to
/// ConstantIntRanges::fromUnsigned.
130  unsigned destWidth) {
131  APInt umin = range.umin().zext(destWidth);
132  APInt umax = range.umax().zext(destWidth);
133  return ConstantIntRanges::fromUnsigned(umin, umax);
134 }
135 
/// `extSIRange`: sign-extends only the signed bounds of `range` to
/// `destWidth`. The unsigned bounds of the result are left to
/// ConstantIntRanges::fromSigned.
137  unsigned destWidth) {
138  APInt smin = range.smin().sext(destWidth);
139  APInt smax = range.smax().sext(destWidth);
140  return ConstantIntRanges::fromSigned(smin, smax);
141 }
142 
/// `truncRange`: truncates `range` to `destWidth` bits. Each view (unsigned
/// and signed) keeps its truncated bounds only when truncation cannot make the
/// values wrap around. Otherwise that view becomes the full range for
/// `destWidth`.
144  unsigned int destWidth) {
145  // If you truncate [0xaaaabbbb, 0xccccbbbb] to its low 16 bits, the
146  // discarded high bits differ, so the values wrap around and the result is not contiguous and includes 0.
147  // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2],
148  // but you can't truncate [255, 257] similarly.
149  bool hasUnsignedRollover =
150  range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
151  APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth)
152  : range.umin().trunc(destWidth);
153  APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
154  : range.umax().trunc(destWidth);
155
156  // Signed post-truncation rollover will not occur when either:
157  // - The high parts of the min and max, plus the sign bit, are the same
158  // - The high halves + sign bit of the min and max are either all 1s or all 0s
159  // and you won't create a [positive, negative] range by truncating.
160  // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8
161  // but not [255, 257]_i16 to a range of i8s. You can also truncate
162  // [-256, -254]_i16 to [0, 2]_i8, but not [-257, -255]_i16 (its high parts differ).
163  // You cannot truncate [-130, 0]_i16 to a signed i8 range: -130_i16 (0xff7e)
164  // truncates to 0x7e (126), which is greater than the truncated max of 0.
165  APInt sminHighPart = range.smin().ashr(destWidth - 1);
166  APInt smaxHighPart = range.smax().ashr(destWidth - 1);
167  bool hasSignedOverflow =
168  (sminHighPart != smaxHighPart) &&
169  !(sminHighPart.isAllOnes() &&
170  (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
171  !(sminHighPart.isZero() && smaxHighPart.isZero());
172  APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
173  : range.smin().trunc(destWidth);
174  APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
175  : range.smax().trunc(destWidth);
176  return {umin, umax, smin, smax};
177 }
178 
179 //===----------------------------------------------------------------------===//
180 // Addition
181 //===----------------------------------------------------------------------===//
182 
/// `inferAdd`: range of lhs + rhs. Each view is bounded by adding the two
/// minima and adding the two maxima.
/// - Without the matching no-wrap flag, a possible overflow makes that view
///   unbounded.
/// - With nuw/nsw set, saturating addition is used instead, so `overflowed`
///   stays false.
/// The result is the intersection of the unsigned and signed views.
/// NOTE(review): source lines 205 and 207 are absent from this listing. They
/// presumably declare `urange`/`srange` via computeBoundsBy(. Confirm.
185  OverflowFlags ovfFlags) {
186  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
187
188  ConstArithStdFn uadd = [=](const APInt &a,
189  const APInt &b) -> std::optional<APInt> {
190  bool overflowed = false;
191  APInt result = any(ovfFlags & OverflowFlags::Nuw)
192  ? a.uadd_sat(b)
193  : a.uadd_ov(b, overflowed);
194  return overflowed ? std::optional<APInt>() : result;
195  };
196  ConstArithStdFn sadd = [=](const APInt &a,
197  const APInt &b) -> std::optional<APInt> {
198  bool overflowed = false;
199  APInt result = any(ovfFlags & OverflowFlags::Nsw)
200  ? a.sadd_sat(b)
201  : a.sadd_ov(b, overflowed);
202  return overflowed ? std::optional<APInt>() : result;
203  };
204
206  uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
208  sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
209  return urange.intersection(srange);
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // Subtraction
214 //===----------------------------------------------------------------------===//
215 
/// `inferSub`: range of lhs - rhs. In each view, the minimum is
/// lhs.min - rhs.max and the maximum is lhs.max - rhs.min.
/// - Without the matching no-wrap flag, a possible overflow makes that view
///   unbounded.
/// - With nuw/nsw set, saturating subtraction is used instead.
/// NOTE(review): source lines 237 and 239 are absent from this listing. They
/// presumably declare `urange`/`srange` via computeBoundsBy(. Confirm.
218  OverflowFlags ovfFlags) {
219  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
220
221  ConstArithStdFn usub = [=](const APInt &a,
222  const APInt &b) -> std::optional<APInt> {
223  bool overflowed = false;
224  APInt result = any(ovfFlags & OverflowFlags::Nuw)
225  ? a.usub_sat(b)
226  : a.usub_ov(b, overflowed);
227  return overflowed ? std::optional<APInt>() : result;
228  };
229  ConstArithStdFn ssub = [=](const APInt &a,
230  const APInt &b) -> std::optional<APInt> {
231  bool overflowed = false;
232  APInt result = any(ovfFlags & OverflowFlags::Nsw)
233  ? a.ssub_sat(b)
234  : a.ssub_ov(b, overflowed);
235  return overflowed ? std::optional<APInt>() : result;
236  };
  // Note the swapped rhs bounds: subtracting the largest rhs gives the minimum.
238  usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
240  ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
241  return urange.intersection(srange);
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // Multiplication
246 //===----------------------------------------------------------------------===//
247 
/// `inferMul`: range of lhs * rhs. Once operand signs can differ, a product
/// is not simply min*min..max*max, so minMaxBy evaluates all four endpoint
/// combinations in each view. Overflow without the matching no-wrap flag
/// makes that view unbounded. The result is the intersection of both views.
250  OverflowFlags ovfFlags) {
251  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
252
253  ConstArithStdFn umul = [=](const APInt &a,
254  const APInt &b) -> std::optional<APInt> {
255  bool overflowed = false;
256  APInt result = any(ovfFlags & OverflowFlags::Nuw)
257  ? a.umul_sat(b)
258  : a.umul_ov(b, overflowed);
259  return overflowed ? std::optional<APInt>() : result;
260  };
261  ConstArithStdFn smul = [=](const APInt &a,
262  const APInt &b) -> std::optional<APInt> {
263  bool overflowed = false;
264  APInt result = any(ovfFlags & OverflowFlags::Nsw)
265  ? a.smul_sat(b)
266  : a.smul_ov(b, overflowed);
267  return overflowed ? std::optional<APInt>() : result;
268  };
269
270  ConstantIntRanges urange =
271  minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
272  /*isSigned=*/false);
273  ConstantIntRanges srange =
274  minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
275  /*isSigned=*/true);
276  return urange.intersection(srange);
277 }
278 
279 //===----------------------------------------------------------------------===//
280 // DivU, CeilDivU (Unsigned division)
281 //===----------------------------------------------------------------------===//
282 
283 /// Fix up division results (e.g. for ceiling and floor division), returning an APInt
284 /// if there has been no overflow.
/// Arguments are the dividend, the divisor, and the truncated quotient.
/// NOTE(review): the first line of this alias (source line 285) is absent
/// from this listing; confirm its callable type against the original file.
286  const APInt &lhs, const APInt &rhs, const APInt &result)>;
287 
289  const ConstantIntRanges &rhs,
290  DivisionFixupFn fixup) {
291  const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
292  &rhsMax = rhs.umax();
293 
294  if (!rhsMin.isZero()) {
295  auto udiv = [&fixup](const APInt &a,
296  const APInt &b) -> std::optional<APInt> {
297  return fixup(a, b, a.udiv(b));
298  };
299  return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
300  /*isSigned=*/false);
301  }
302 
303  APInt umin = APInt::getZero(rhsMin.getBitWidth());
304  if (lhsMin.uge(rhsMax) && !rhsMax.isZero())
305  umin = lhsMin.udiv(rhsMax);
306 
307  // X u/ Y u<= X.
308  APInt umax = lhsMax;
309  return ConstantIntRanges::fromUnsigned(umin, umax);
310 }
311 
/// `inferDivU`: unsigned division. The fix-up returns the truncated quotient
/// unchanged.
314  return inferDivURange(argRanges[0], argRanges[1],
315  [](const APInt &lhs, const APInt &rhs,
316  const APInt &result) { return result; });
317 }
318 
/// `inferCeilDivU`: unsigned ceiling division. The truncated quotient is
/// incremented by one when the division is inexact. If that increment
/// overflows, std::nullopt is returned, which makes the range maximal.
321  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
322
323  auto ceilDivUIFix = [](const APInt &lhs, const APInt &rhs,
324  const APInt &result) -> std::optional<APInt> {
325  if (!lhs.urem(rhs).isZero()) {
326  bool overflowed = false;
327  APInt corrected =
328  result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
329  return overflowed ? std::optional<APInt>() : corrected;
330  }
331  return result;
332  };
333  return inferDivURange(lhs, rhs, ceilDivUIFix);
334 }
335 
336 //===----------------------------------------------------------------------===//
337 // DivS, CeilDivS, FloorDivS (Signed division)
338 //===----------------------------------------------------------------------===//
339 
341  const ConstantIntRanges &rhs,
342  DivisionFixupFn fixup) {
343  const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
344  &rhsMax = rhs.smax();
345  bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
346 
347  if (canDivide) {
348  auto sdiv = [&fixup](const APInt &a,
349  const APInt &b) -> std::optional<APInt> {
350  bool overflowed = false;
351  APInt result = a.sdiv_ov(b, overflowed);
352  return overflowed ? std::optional<APInt>() : fixup(a, b, result);
353  };
354  return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
355  /*isSigned=*/true);
356  }
357  return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
358 }
359 
/// `inferDivS`: signed (truncating) division. The fix-up returns the quotient
/// unchanged.
362  return inferDivSRange(argRanges[0], argRanges[1],
363  [](const APInt &lhs, const APInt &rhs,
364  const APInt &result) { return result; });
365 }
366 
/// `inferCeilDivS`: signed ceiling division.
/// - The truncated quotient is incremented when the division is inexact and
///   both operands have the same sign.
/// - INT_MIN divided by a divisor greater than 1 is handled specially (see the
///   comment inside the fix-up).
/// - Because that special case makes the result jump when lhs.smin() is
///   INT_MIN, the range for [INT_MIN + 1, smax] is unioned in as well.
369  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
370
371  auto ceilDivSIFix = [](const APInt &lhs, const APInt &rhs,
372  const APInt &result) -> std::optional<APInt> {
373  if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
374  bool overflowed = false;
375  APInt corrected =
376  result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
377  return overflowed ? std::optional<APInt>() : corrected;
378  }
379  // Special case where the usual implementation of ceilDiv causes
380  // INT_MIN / [positive number] to be positive. This doesn't match the
381  // definition of signed ceiling division mathematically, but it prevents
382  // inconsistent constant-folding results. This arises because (-int_min) is
383  // still negative, so -(-int_min / b) is -(int_min / b), which is
384  // positive. See #115293.
385  if (lhs.isMinSignedValue() && rhs.sgt(1)) {
386  return -result;
387  }
388  return result;
389  };
390  ConstantIntRanges result = inferDivSRange(lhs, rhs, ceilDivSIFix);
391  if (lhs.smin().isMinSignedValue() && lhs.smax().sgt(lhs.smin())) {
392  // If the lhs range includes INT_MIN and is not a single value, the result
393  // can suddenly wrap to a positive value, skipping the entire negative range;
394  // add the [INT_MIN + 1, smax()] range to the result to handle this.
395  auto newLhs = ConstantIntRanges::fromSigned(lhs.smin() + 1, lhs.smax());
396  result = result.rangeUnion(inferDivSRange(newLhs, rhs, ceilDivSIFix));
397  }
398  return result;
399 }
400 
/// `inferFloorDivS`: signed floor division. The truncated quotient is
/// decremented when the division is inexact and the operands have different
/// signs. If that decrement overflows, std::nullopt is returned, which makes
/// the range maximal.
403  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
404
405  auto floorDivSIFix = [](const APInt &lhs, const APInt &rhs,
406  const APInt &result) -> std::optional<APInt> {
407  if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
408  bool overflowed = false;
409  APInt corrected =
410  result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
411  return overflowed ? std::optional<APInt>() : corrected;
412  }
413  return result;
414  };
415  return inferDivSRange(lhs, rhs, floorDivSIFix);
416 }
417 
418 //===----------------------------------------------------------------------===//
419 // Signed remainder (RemS)
420 //===----------------------------------------------------------------------===//
421 
/// `inferRemS`: signed remainder.
/// - When the divisor range excludes zero, the result lies within
///   [-(maxDivisor - 1), maxDivisor - 1], where maxDivisor is the largest
///   divisor magnitude.
/// - That range is narrowed to [0, ...] or [..., 0] when the dividend cannot
///   be negative or cannot be positive.
/// - A constant divisor with a short enough dividend range may be tightened
///   further.
/// - If zero may be a divisor, the full signed range is returned.
424  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
425  const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
426  &rhsMax = rhs.smax();
427
428  unsigned width = rhsMax.getBitWidth();
429  APInt smin = APInt::getSignedMinValue(width);
430  APInt smax = APInt::getSignedMaxValue(width);
431  // No bounds if zero could be a divisor.
432  bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
433  if (canBound) {
  // If rhsMin is INT_MIN, abs() wraps back to INT_MIN, but maxDivisor - 1
  // then wraps to INT_MAX, which is still a valid magnitude bound.
434  APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
435  bool canNegativeDividend = lhsMin.isNegative();
436  bool canPositiveDividend = lhsMax.isStrictlyPositive();
437  APInt zero = APInt::getZero(maxDivisor.getBitWidth());
438  APInt maxPositiveResult = maxDivisor - 1;
439  APInt minNegativeResult = -maxPositiveResult;
440  smin = canNegativeDividend ? minNegativeResult : zero;
441  smax = canPositiveDividend ? maxPositiveResult : zero;
442  // Special case: sweeping out a contiguous range in N/[modulus].
  // If the dividend range is shorter than the modulus and the endpoint
  // remainders are in order, the remainders did not wrap in between.
443  if (rhsMin == rhsMax) {
444  if ((lhsMax - lhsMin).ult(maxDivisor)) {
445  APInt minRem = lhsMin.srem(maxDivisor);
446  APInt maxRem = lhsMax.srem(maxDivisor);
447  if (minRem.sle(maxRem)) {
448  smin = minRem;
449  smax = maxRem;
450  }
451  }
452  }
453  }
454  return ConstantIntRanges::fromSigned(smin, smax);
455 }
456 
457 //===----------------------------------------------------------------------===//
458 // Unsigned remainder (RemU)
459 //===----------------------------------------------------------------------===//
460 
/// `inferRemU`: unsigned remainder. The result is in
/// [0, min(rhs.umax() - 1, lhs.umax())]. It is tightened when the divisor is
/// a single non-zero value and the dividend range sweeps less than one full
/// modulus without wrapping.
463  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
464  const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
465
466  unsigned width = rhsMin.getBitWidth();
467  APInt umin = APInt::getZero(width);
468  // Remainder can't be larger than either of its arguments.
  // If rhsMax is 0, rhsMax - 1 wraps to all-ones, so umax falls back to
  // lhs.umax().
469  APInt umax = llvm::APIntOps::umin((rhsMax - 1), lhs.umax());
470
471  if (!rhsMin.isZero()) {
472  // Special case: sweeping out a contiguous range in N/[modulus]
473  if (rhsMin == rhsMax) {
474  const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
475  if ((lhsMax - lhsMin).ult(rhsMax)) {
476  APInt minRem = lhsMin.urem(rhsMax);
477  APInt maxRem = lhsMax.urem(rhsMax);
478  if (minRem.ule(maxRem)) {
479  umin = minRem;
480  umax = maxRem;
481  }
482  }
483  }
484  }
485  return ConstantIntRanges::fromUnsigned(umin, umax);
486 }
487 
488 //===----------------------------------------------------------------------===//
489 // Max and min (MaxS, MaxU, MinS, MinU)
490 //===----------------------------------------------------------------------===//
491 
/// `inferMaxS`: signed max(lhs, rhs). Each bound is the larger of the
/// corresponding operand bounds.
494  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
495
496  const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
497  const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
498  return ConstantIntRanges::fromSigned(smin, smax);
499 }
500 
/// `inferMaxU`: unsigned max(lhs, rhs). Each bound is the larger of the
/// corresponding operand bounds.
503  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
504
505  const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
506  const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
507  return ConstantIntRanges::fromUnsigned(umin, umax);
508 }
509 
/// `inferMinS`: signed min(lhs, rhs). Each bound is the smaller of the
/// corresponding operand bounds.
512  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
513
514  const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
515  const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
516  return ConstantIntRanges::fromSigned(smin, smax);
517 }
518 
/// `inferMinU`: unsigned min(lhs, rhs). Each bound is the smaller of the
/// corresponding operand bounds.
521  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
522
523  const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
524  const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
525  return ConstantIntRanges::fromUnsigned(umin, umax);
526 }
527 
528 //===----------------------------------------------------------------------===//
529 // Bitwise operators (And, Or, Xor)
530 //===----------------------------------------------------------------------===//
531 
532 /// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
533 /// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
534 /// that both bonuds have in common. This gives us a consertive approximation
535 /// for what values can be passed to bitwise operations.
536 static std::tuple<APInt, APInt>
538  APInt leftVal = bound.umin(), rightVal = bound.umax();
539  unsigned bitwidth = leftVal.getBitWidth();
540  unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
541  leftVal.clearLowBits(differingBits);
542  rightVal.setLowBits(differingBits);
543  return std::make_tuple(std::move(leftVal), std::move(rightVal));
544 }
545 
/// `inferAnd`: unsigned range of lhs & rhs, computed over the widened bounds
/// of both operands. Every value in a widened range has all bits of its lower
/// bound and no bits outside its upper bound. So a & b is bracketed by the
/// results at the endpoint combinations, which minMaxBy evaluates.
548  auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
549  auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
550  auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
551  return a & b;
552  };
553  return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
554  /*isSigned=*/false);
555 }
556 
/// `inferOr`: unsigned range of lhs | rhs, computed over the widened bounds
/// of both operands. As with inferAnd, the results at the endpoint
/// combinations bracket every possible a | b.
559  auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
560  auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
561  auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
562  return a | b;
563  };
564  return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
565  /*isSigned=*/false);
566 }
567 
568 /// Get bitmask of all bits which can change while iterating in
569 /// [bound.umin(), bound.umax()].
570 static APInt getVaryingBitsMask(const ConstantIntRanges &bound) {
571  APInt leftVal = bound.umin(), rightVal = bound.umax();
572  unsigned bitwidth = leftVal.getBitWidth();
573  unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
574  return APInt::getLowBitsSet(bitwidth, differingBits);
575 }
576 
/// `inferXor`: unsigned range of lhs ^ rhs. Bits that cannot vary in either
/// operand are fixed in the result. Every varying bit is assumed 0 for the
/// minimum and 1 for the maximum.
/// NOTE(review): the signature and source line 586 are absent from this
/// listing. Line 586 presumably returns a range built from `min`/`max`;
/// confirm. Also, `lhs`/`rhs` are copied where const references would suffice.
579  // Construct mask of varying bits for both ranges, xor values and then replace
580  // masked bits with 0s and 1s to get min and max values respectively.
581  ConstantIntRanges lhs = argRanges[0], rhs = argRanges[1];
582  APInt mask = getVaryingBitsMask(lhs) | getVaryingBitsMask(rhs);
583  APInt res = lhs.umin() ^ rhs.umin();
584  APInt min = res & ~mask;
585  APInt max = res | mask;
587 }
588 
589 //===----------------------------------------------------------------------===//
590 // Shifts (Shl, ShrS, ShrU)
591 //===----------------------------------------------------------------------===//
592 
/// `inferShl`: range of lhs << rhs.
/// - The shift amount is always taken from rhs's unsigned bounds, in both
///   views.
/// - Overflow (including shift amounts >= the bit width) without the matching
///   no-wrap flag makes that view unbounded.
/// - With nuw/nsw set, saturating shifts are used.
/// The result is the intersection of the unsigned and signed views.
595  OverflowFlags ovfFlags) {
596  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
597  const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax();
598
599  // The signed/unsigned overflow behavior of shl by `rhs` matches a mul with
600  // 2^rhs.
601  ConstArithStdFn ushl = [=](const APInt &l,
602  const APInt &r) -> std::optional<APInt> {
603  bool overflowed = false;
604  APInt result = any(ovfFlags & OverflowFlags::Nuw)
605  ? l.ushl_sat(r)
606  : l.ushl_ov(r, overflowed);
607  return overflowed ? std::optional<APInt>() : result;
608  };
609  ConstArithStdFn sshl = [=](const APInt &l,
610  const APInt &r) -> std::optional<APInt> {
611  bool overflowed = false;
612  APInt result = any(ovfFlags & OverflowFlags::Nsw)
613  ? l.sshl_sat(r)
614  : l.sshl_ov(r, overflowed);
615  return overflowed ? std::optional<APInt>() : result;
616  };
617
618  ConstantIntRanges urange =
619  minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhsUMin, rhsUMax},
620  /*isSigned=*/false);
621  ConstantIntRanges srange =
622  minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax},
623  /*isSigned=*/true);
624  return urange.intersection(srange);
625 }
626 
/// `inferShrS`: signed range of lhs >> rhs (arithmetic shift). A possible
/// shift amount >= the bit width makes the range maximal.
629  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
630
631  auto ashr = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
632  return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
633  };
634
635  return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
636  /*isSigned=*/true);
637 }
638 
/// `inferShrU`: unsigned range of lhs >> rhs (logical shift). A possible
/// shift amount >= the bit width makes the range maximal.
641  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
642
643  auto lshr = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
644  return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
645  };
646  return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
647  /*isSigned=*/false);
648 }
649 
650 //===----------------------------------------------------------------------===//
651 // Comparisons (Cmp)
652 //===----------------------------------------------------------------------===//
653 
/// `invertPredicate`: maps a comparison predicate to another predicate.
/// evaluatePred treats a statically-true result for it as proof that `pred`
/// is false, so it is presumably the logical negation of `pred`.
/// NOTE(review): the function header and every case of this switch are absent
/// from this listing (numbering jumps from 655 to 676). Confirm against the
/// original file.
655  switch (pred) {
676  }
677  llvm_unreachable("unknown cmp predicate value");
678 }
679 
/// `isStaticallyTrue`: returns true only if `pred` holds for every value in
/// `lhs` compared with every value in `rhs`. A false result means "not
/// provably true", not "provably false".
/// - Ordering predicates compare the extreme bounds of the relevant view.
/// - eq requires both ranges to be the same constant.
/// - ne requires the ranges to be disjoint in both the signed and the
///   unsigned view.
/// NOTE(review): the header, the case labels, and source lines 711/713 are
/// absent from this listing. The latter presumably test the sgt/ugt
/// counterparts of the visible slt/ult calls; confirm.
681  const ConstantIntRanges &lhs,
682  const ConstantIntRanges &rhs) {
683  switch (pred) {
685  return lhs.smax().sle(rhs.smin());
687  return lhs.smax().slt(rhs.smin());
689  return lhs.umax().ule(rhs.umin());
691  return lhs.umax().ult(rhs.umin());
693  return lhs.smin().sge(rhs.smax());
695  return lhs.smin().sgt(rhs.smax());
697  return lhs.umin().uge(rhs.umax());
699  return lhs.umin().ugt(rhs.umax());
701  std::optional<APInt> lhsConst = lhs.getConstantValue();
702  std::optional<APInt> rhsConst = rhs.getConstantValue();
703  return lhsConst && rhsConst && lhsConst == rhsConst;
704  }
706  // While equality requires that there is an interpretation of the preceding
707  // computations that produces equal constants, whether that be signed or
708  // unsigned, statically determining inequality requires that neither
709  // interpretation produce potentially overlapping ranges.
710  bool sne = isStaticallyTrue(intrange::CmpPredicate::slt, lhs, rhs) ||
712  bool une = isStaticallyTrue(intrange::CmpPredicate::ult, lhs, rhs) ||
714  return sne && une;
715  }
716  }
717  return false;
718 }
719 
/// `evaluatePred`: returns true if `pred` is statically true for the given
/// ranges. Returns false if its inverted predicate is statically true.
/// Returns std::nullopt when neither can be proven.
721  const ConstantIntRanges &lhs,
722  const ConstantIntRanges &rhs) {
723  if (isStaticallyTrue(pred, lhs, rhs))
724  return true;
725  if (isStaticallyTrue(invertPredicate(pred), lhs, rhs))
726  return false;
727  return std::nullopt;
728 }
729 
730 //===----------------------------------------------------------------------===//
731 // Shaped type dimension accessors / ShapedDimOpInterface
732 //===----------------------------------------------------------------------===//
733 
/// `inferShapedDimOpInterface`: range of a dimension-size result.
/// - Unranked shaped type: [0, signed max of the result width].
/// - Ranked shaped type: the union over every dimension index allowed by
///   `maybeDim` (clamped to [0, rank - 1]). A static size contributes its
///   constant; a dynamic size contributes [0, signed max].
/// - If no index survives the clamping, the result falls back to
///   [0, signed max].
/// NOTE(review): getSExtValue() requires the dim bounds to fit in 64 bits.
736  const IntegerValueRange &maybeDim) {
737  unsigned width =
738  ConstantIntRanges::getStorageBitwidth(op->getResult(0).getType());
739  APInt zero = APInt::getZero(width);
740  APInt typeMax = APInt::getSignedMaxValue(width);
741
742  auto shapedTy = cast<ShapedType>(op.getShapedValue().getType());
743  if (!shapedTy.hasRank())
744  return ConstantIntRanges::fromSigned(zero, typeMax);
745
746  int64_t rank = shapedTy.getRank();
747  int64_t minDim = 0;
748  int64_t maxDim = rank - 1;
  // Restrict the dimensions to consider to those the index range allows.
749  if (!maybeDim.isUninitialized()) {
750  const ConstantIntRanges &dim = maybeDim.getValue();
751  minDim = std::max(minDim, dim.smin().getSExtValue());
752  maxDim = std::min(maxDim, dim.smax().getSExtValue());
753  }
754
755  std::optional<ConstantIntRanges> result;
756  auto joinResult = [&](const ConstantIntRanges &thisResult) {
757  if (!result.has_value())
758  result = thisResult;
759  else
760  result = result->rangeUnion(thisResult);
761  };
762  for (int64_t i = minDim; i <= maxDim; ++i) {
763  int64_t length = shapedTy.getDimSize(i);
764
765  if (ShapedType::isDynamic(length))
766  joinResult(ConstantIntRanges::fromSigned(zero, typeMax));
767  else
768  joinResult(ConstantIntRanges::constant(APInt(width, length)));
769  }
770  return result.value_or(ConstantIntRanges::fromSigned(zero, typeMax));
771 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs, const ConstantIntRanges &rhs, DivisionFixupFn fixup)
static bool isStaticallyTrue(intrange::CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred)
static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef< APInt > lhs, ArrayRef< APInt > rhs, bool isSigned)
Compute the minimum and maximum of (op(l, r) for l in lhs for r in rhs), ignoring unbounded values.
static std::tuple< APInt, APInt > widenBitwiseBounds(const ConstantIntRanges &bound)
"Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111,...
static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, const ConstantIntRanges &rhs, DivisionFixupFn fixup)
std::function< std::optional< APInt >(const APInt &, const APInt &)> ConstArithStdFn
static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, const APInt &minRight, const APInt &maxLeft, const APInt &maxRight, bool isSigned)
Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible. If either computation overflows,...
static APInt getVaryingBitsMask(const ConstantIntRanges &bound)
Get bitmask of all bits which can change while iterating in [bound.umin(), bound.umax()].
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges maxRange(unsigned bitwidth)
Create a ConstantIntRanges with the maximum bounds for the width bitwidth, that is - [0,...
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create an ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
static ConstantIntRanges range(const APInt &min, const APInt &max, bool isSigned)
Create a ConstantIntRanges whose minimum is min and maximum is max with isSigned specifying if the mi...
ConstantIntRanges intersection(const ConstantIntRanges &other) const
Returns the intersection (computed separately for signed and unsigned bounds) of this range and other...
static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax)
Create an ConstantIntRanges with the signed minimum and maximum equal to smin and smax,...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const
Returns the union (computed separately for signed and unsigned bounds) of this range and other.
This lattice value represents the integer range of an SSA value.
const ConstantIntRanges & getValue() const
Get the known integer value range.
bool isUninitialized() const
Whether the range is uninitialized.
ConstantIntRanges inferAnd(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShl(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferIndexOp(const InferRangeFn &inferFn, ArrayRef< ConstantIntRanges > argRanges, CmpMode mode)
Compute inferFn on ranges, whose size should be the index storage bitwidth.
ConstantIntRanges inferShrS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extSIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the signed values in range to sign-extend it to destWidth.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
ConstantIntRanges inferMinS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMaxU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemS(ArrayRef< ConstantIntRanges > argRanges)
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
ConstantIntRanges inferOr(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth)
Independently zero-extend the unsigned values and sign-extend the signed values in range to destWidth...
ConstantIntRanges inferSub(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges truncRange(const ConstantIntRanges &range, unsigned destWidth)
Truncate range to destWidth bits, taking care to handle cases such as the truncation of [255,...
ConstantIntRanges inferDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMinU(ArrayRef< ConstantIntRanges > argRanges)
std::function< ConstantIntRanges(ArrayRef< ConstantIntRanges >)> InferRangeFn
Function that performs inference on an array of ConstantIntRanges, abstracted away here to permit wri...
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferXor(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShrU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemU(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMinWidth
ConstantIntRanges inferFloorDivS(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMaxWidth
ConstantIntRanges inferCeilDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extUIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the unsigned values in range to zero-extend it to destWidth.
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
ConstantIntRanges inferMaxS(ArrayRef< ConstantIntRanges > argRanges)
Include the generated interface declarations.