19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/Debug.h"
29 #define DEBUG_TYPE "int-range-analysis"
40 std::function<std::optional<APInt>(
const APInt &,
const APInt &)>;
45 const APInt &minRight,
47 const APInt &maxRight,
bool isSigned) {
48 std::optional<APInt> maybeMin = op(minLeft, minRight);
49 std::optional<APInt> maybeMax = op(maxLeft, maxRight);
50 if (maybeMin && maybeMax)
59 unsigned width = lhs[0].getBitWidth();
61 isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
63 isSigned ? APInt::getSignedMinValue(width) :
APInt::getZero(width);
64 for (
const APInt &left : lhs) {
65 for (
const APInt &right : rhs) {
66 std::optional<APInt> maybeThisResult = op(left, right);
69 APInt result = std::move(*maybeThisResult);
70 min = (isSigned ? result.slt(
min) : result.ult(
min)) ? result :
min;
71 max = (isSigned ? result.sgt(
max) : result.ugt(
max)) ? result :
max;
87 llvm::transform(argRanges, std::back_inserter(truncated),
97 LLVM_DEBUG(llvm::dbgs() <<
"Index handling: 64-bit result = " << sixtyFour
98 <<
" 32-bit = " << thirtyTwo <<
"\n");
99 bool truncEqual =
false;
102 truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo);
105 truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.
smin() &&
106 thirtyTwo.smax() == sixtyFourAsThirtyTwo.
smax());
109 truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.
umin() &&
110 thirtyTwo.umax() == sixtyFourAsThirtyTwo.
umax());
121 unsigned int destWidth) {
122 APInt umin = range.
umin().zext(destWidth);
123 APInt umax = range.
umax().zext(destWidth);
124 APInt smin = range.
smin().sext(destWidth);
125 APInt smax = range.
smax().sext(destWidth);
126 return {umin, umax, smin, smax};
130 unsigned destWidth) {
131 APInt umin = range.
umin().zext(destWidth);
132 APInt umax = range.
umax().zext(destWidth);
137 unsigned destWidth) {
138 APInt smin = range.
smin().sext(destWidth);
139 APInt smax = range.
smax().sext(destWidth);
144 unsigned int destWidth) {
149 bool hasUnsignedRollover =
150 range.
umin().lshr(destWidth) != range.
umax().lshr(destWidth);
152 : range.umin().trunc(destWidth);
153 APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
154 : range.umax().trunc(destWidth);
165 APInt sminHighPart = range.
smin().ashr(destWidth - 1);
166 APInt smaxHighPart = range.
smax().ashr(destWidth - 1);
167 bool hasSignedOverflow =
168 (sminHighPart != smaxHighPart) &&
169 !(sminHighPart.isAllOnes() &&
170 (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
171 !(sminHighPart.isZero() && smaxHighPart.isZero());
172 APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
173 : range.smin().trunc(destWidth);
174 APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
175 : range.smax().trunc(destWidth);
176 return {umin, umax, smin, smax};
189 const APInt &b) -> std::optional<APInt> {
190 bool overflowed =
false;
191 APInt result = any(ovfFlags & OverflowFlags::Nuw)
193 : a.uadd_ov(b, overflowed);
194 return overflowed ? std::optional<APInt>() : result;
197 const APInt &b) -> std::optional<APInt> {
198 bool overflowed =
false;
199 APInt result = any(ovfFlags & OverflowFlags::Nsw)
201 : a.sadd_ov(b, overflowed);
202 return overflowed ? std::optional<APInt>() : result;
206 uadd, lhs.
umin(), rhs.umin(), lhs.
umax(), rhs.umax(),
false);
208 sadd, lhs.
smin(), rhs.smin(), lhs.
smax(), rhs.smax(),
true);
222 const APInt &b) -> std::optional<APInt> {
223 bool overflowed =
false;
224 APInt result = any(ovfFlags & OverflowFlags::Nuw)
226 : a.usub_ov(b, overflowed);
227 return overflowed ? std::optional<APInt>() : result;
230 const APInt &b) -> std::optional<APInt> {
231 bool overflowed =
false;
232 APInt result = any(ovfFlags & OverflowFlags::Nsw)
234 : a.ssub_ov(b, overflowed);
235 return overflowed ? std::optional<APInt>() : result;
238 usub, lhs.
umin(), rhs.umax(), lhs.
umax(), rhs.umin(),
false);
240 ssub, lhs.
smin(), rhs.smax(), lhs.
smax(), rhs.smin(),
true);
254 const APInt &b) -> std::optional<APInt> {
255 bool overflowed =
false;
256 APInt result = any(ovfFlags & OverflowFlags::Nuw)
258 : a.umul_ov(b, overflowed);
259 return overflowed ? std::optional<APInt>() : result;
262 const APInt &b) -> std::optional<APInt> {
263 bool overflowed =
false;
264 APInt result = any(ovfFlags & OverflowFlags::Nsw)
266 : a.smul_ov(b, overflowed);
267 return overflowed ? std::optional<APInt>() : result;
286 const APInt &lhs,
const APInt &rhs,
const APInt &result)>;
291 const APInt &lhsMin = lhs.
umin(), &lhsMax = lhs.
umax(), &rhsMin = rhs.
umin(),
292 &rhsMax = rhs.
umax();
294 if (!rhsMin.isZero()) {
295 auto udiv = [&fixup](
const APInt &a,
296 const APInt &b) -> std::optional<APInt> {
297 return fixup(a, b, a.udiv(b));
299 return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
304 if (lhsMin.uge(rhsMax) && !rhsMax.isZero())
305 umin = lhsMin.udiv(rhsMax);
315 [](
const APInt &lhs,
const APInt &rhs,
316 const APInt &result) {
return result; });
323 auto ceilDivUIFix = [](
const APInt &lhs,
const APInt &rhs,
324 const APInt &result) -> std::optional<APInt> {
325 if (!lhs.urem(rhs).isZero()) {
326 bool overflowed =
false;
328 result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
329 return overflowed ? std::optional<APInt>() : corrected;
343 const APInt &lhsMin = lhs.
smin(), &lhsMax = lhs.
smax(), &rhsMin = rhs.
smin(),
344 &rhsMax = rhs.
smax();
345 bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
348 auto sdiv = [&fixup](
const APInt &a,
349 const APInt &b) -> std::optional<APInt> {
350 bool overflowed =
false;
351 APInt result = a.sdiv_ov(b, overflowed);
352 return overflowed ? std::optional<APInt>() : fixup(a, b, result);
354 return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
363 [](
const APInt &lhs,
const APInt &rhs,
364 const APInt &result) {
return result; });
371 auto ceilDivSIFix = [](
const APInt &lhs,
const APInt &rhs,
372 const APInt &result) -> std::optional<APInt> {
373 if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
374 bool overflowed =
false;
376 result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
377 return overflowed ? std::optional<APInt>() : corrected;
385 if (lhs.isMinSignedValue() && rhs.sgt(1)) {
391 if (lhs.
smin().isMinSignedValue() && lhs.
smax().sgt(lhs.
smin())) {
405 auto floorDivSIFix = [](
const APInt &lhs,
const APInt &rhs,
406 const APInt &result) -> std::optional<APInt> {
407 if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
408 bool overflowed =
false;
410 result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
411 return overflowed ? std::optional<APInt>() : corrected;
425 const APInt &lhsMin = lhs.
smin(), &lhsMax = lhs.
smax(), &rhsMin = rhs.smin(),
426 &rhsMax = rhs.smax();
428 unsigned width = rhsMax.getBitWidth();
429 APInt smin = APInt::getSignedMinValue(width);
430 APInt smax = APInt::getSignedMaxValue(width);
432 bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
434 APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
435 bool canNegativeDividend = lhsMin.isNegative();
436 bool canPositiveDividend = lhsMax.isStrictlyPositive();
438 APInt maxPositiveResult = maxDivisor - 1;
439 APInt minNegativeResult = -maxPositiveResult;
440 smin = canNegativeDividend ? minNegativeResult : zero;
441 smax = canPositiveDividend ? maxPositiveResult : zero;
443 if (rhsMin == rhsMax) {
444 if ((lhsMax - lhsMin).
ult(maxDivisor)) {
445 APInt minRem = lhsMin.srem(maxDivisor);
446 APInt maxRem = lhsMax.srem(maxDivisor);
447 if (minRem.sle(maxRem)) {
464 const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
466 unsigned width = rhsMin.getBitWidth();
469 APInt umax = llvm::APIntOps::umin((rhsMax - 1), lhs.
umax());
471 if (!rhsMin.isZero()) {
473 if (rhsMin == rhsMax) {
474 const APInt &lhsMin = lhs.
umin(), &lhsMax = lhs.
umax();
475 if ((lhsMax - lhsMin).
ult(rhsMax)) {
476 APInt minRem = lhsMin.urem(rhsMax);
477 APInt maxRem = lhsMax.urem(rhsMax);
478 if (minRem.ule(maxRem)) {
496 const APInt &smin = lhs.
smin().sgt(rhs.smin()) ? lhs.
smin() : rhs.smin();
497 const APInt &smax = lhs.
smax().sgt(rhs.smax()) ? lhs.
smax() : rhs.smax();
505 const APInt &umin = lhs.
umin().ugt(rhs.umin()) ? lhs.
umin() : rhs.umin();
506 const APInt &umax = lhs.
umax().ugt(rhs.umax()) ? lhs.
umax() : rhs.umax();
514 const APInt &smin = lhs.
smin().slt(rhs.smin()) ? lhs.
smin() : rhs.smin();
515 const APInt &smax = lhs.
smax().slt(rhs.smax()) ? lhs.
smax() : rhs.smax();
523 const APInt &umin = lhs.
umin().ult(rhs.umin()) ? lhs.
umin() : rhs.umin();
524 const APInt &umax = lhs.
umax().ult(rhs.umax()) ? lhs.
umax() : rhs.umax();
536 static std::tuple<APInt, APInt>
538 APInt leftVal = bound.
umin(), rightVal = bound.
umax();
539 unsigned bitwidth = leftVal.getBitWidth();
540 unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
541 leftVal.clearLowBits(differingBits);
542 rightVal.setLowBits(differingBits);
543 return std::make_tuple(std::move(leftVal), std::move(rightVal));
550 auto andi = [](
const APInt &a,
const APInt &b) -> std::optional<APInt> {
553 return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
561 auto ori = [](
const APInt &a,
const APInt &b) -> std::optional<APInt> {
564 return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
571 APInt leftVal = bound.
umin(), rightVal = bound.
umax();
572 unsigned bitwidth = leftVal.getBitWidth();
573 unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
574 return APInt::getLowBitsSet(bitwidth, differingBits);
583 APInt res = lhs.
umin() ^ rhs.umin();
584 APInt
min = res & ~mask;
585 APInt
max = res | mask;
597 const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax();
602 const APInt &r) -> std::optional<APInt> {
603 bool overflowed =
false;
604 APInt result = any(ovfFlags & OverflowFlags::Nuw)
606 : l.ushl_ov(r, overflowed);
607 return overflowed ? std::optional<APInt>() : result;
610 const APInt &r) -> std::optional<APInt> {
611 bool overflowed =
false;
612 APInt result = any(ovfFlags & OverflowFlags::Nsw)
614 : l.sshl_ov(r, overflowed);
615 return overflowed ? std::optional<APInt>() : result;
631 auto ashr = [](
const APInt &l,
const APInt &r) -> std::optional<APInt> {
632 return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
643 auto lshr = [](
const APInt &l,
const APInt &r) -> std::optional<APInt> {
644 return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
677 llvm_unreachable(
"unknown cmp predicate value");
703 return lhsConst && rhsConst && lhsConst == rhsConst;
740 APInt typeMax = APInt::getSignedMaxValue(width);
742 auto shapedTy = cast<ShapedType>(op.getShapedValue().getType());
743 if (!shapedTy.hasRank())
746 int64_t rank = shapedTy.getRank();
748 int64_t maxDim = rank - 1;
755 std::optional<ConstantIntRanges> result;
757 if (!result.has_value())
760 result = result->rangeUnion(thisResult);
762 for (int64_t i = minDim; i <= maxDim; ++i) {
763 int64_t length = shapedTy.getDimSize(i);
765 if (ShapedType::isDynamic(length))
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs, const ConstantIntRanges &rhs, DivisionFixupFn fixup)
static bool isStaticallyTrue(intrange::CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred)
static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef< APInt > lhs, ArrayRef< APInt > rhs, bool isSigned)
Compute the minimum and maximum of (op(l, r) for l in lhs for r in rhs), ignoring unbounded values.
static std::tuple< APInt, APInt > widenBitwiseBounds(const ConstantIntRanges &bound)
"Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111,...
static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, const ConstantIntRanges &rhs, DivisionFixupFn fixup)
std::function< std::optional< APInt >(const APInt &, const APInt &)> ConstArithStdFn
static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, const APInt &minRight, const APInt &maxLeft, const APInt &maxRight, bool isSigned)
Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible. If either computation overflows,...
static APInt getVaryingBitsMask(const ConstantIntRanges &bound)
Get bitmask of all bits which can change while iterating in [bound.umin(), bound.umax()].
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges maxRange(unsigned bitwidth)
Create a ConstantIntRanges with the maximum bounds for the width bitwidth, that is - [0,...
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create a ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
static ConstantIntRanges range(const APInt &min, const APInt &max, bool isSigned)
Create a ConstantIntRanges whose minimum is min and maximum is max with isSigned specifying if the mi...
ConstantIntRanges intersection(const ConstantIntRanges &other) const
Returns the intersection (computed separately for signed and unsigned bounds) of this range and other...
static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax)
Create a ConstantIntRanges with the signed minimum and maximum equal to smin and smax,...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const
Returns the union (computed separately for signed and unsigned bounds) of this range and other.
This lattice value represents the integer range of an SSA value.
const ConstantIntRanges & getValue() const
Get the known integer value range.
bool isUninitialized() const
Whether the range is uninitialized.
ConstantIntRanges inferAnd(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShl(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferIndexOp(const InferRangeFn &inferFn, ArrayRef< ConstantIntRanges > argRanges, CmpMode mode)
Compute inferFn on ranges, whose size should be the index storage bitwidth.
ConstantIntRanges inferShrS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extSIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the signed values in range to sign-extend it to destWidth.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
ConstantIntRanges inferMinS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMaxU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemS(ArrayRef< ConstantIntRanges > argRanges)
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
ConstantIntRanges inferOr(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth)
Independently zero-extend the unsigned values and sign-extend the signed values in range to destWidth...
ConstantIntRanges inferSub(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges truncRange(const ConstantIntRanges &range, unsigned destWidth)
Truncate range to destWidth bits, taking care to handle cases such as the truncation of [255,...
ConstantIntRanges inferDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMinU(ArrayRef< ConstantIntRanges > argRanges)
std::function< ConstantIntRanges(ArrayRef< ConstantIntRanges >)> InferRangeFn
Function that performs inference on an array of ConstantIntRanges, abstracted away here to permit wri...
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferXor(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShrU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemU(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMinWidth
ConstantIntRanges inferFloorDivS(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMaxWidth
ConstantIntRanges inferCeilDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extUIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the unsigned values in range to zero-extend it to destWidth.
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
ConstantIntRanges inferMaxS(ArrayRef< ConstantIntRanges > argRanges)
Include the generated interface declarations.