19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/Debug.h"
29 #define DEBUG_TYPE "int-range-analysis"
40 std::function<std::optional<APInt>(
const APInt &,
const APInt &)>;
45 const APInt &minRight,
47 const APInt &maxRight,
bool isSigned) {
48 std::optional<APInt> maybeMin = op(minLeft, minRight);
49 std::optional<APInt> maybeMax = op(maxLeft, maxRight);
50 if (maybeMin && maybeMax)
59 unsigned width = lhs[0].getBitWidth();
61 isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
63 isSigned ? APInt::getSignedMinValue(width) :
APInt::getZero(width);
64 for (
const APInt &left : lhs) {
65 for (
const APInt &right : rhs) {
66 std::optional<APInt> maybeThisResult = op(left, right);
69 APInt result = std::move(*maybeThisResult);
70 min = (isSigned ? result.slt(
min) : result.ult(
min)) ? result :
min;
71 max = (isSigned ? result.sgt(
max) : result.ugt(
max)) ? result :
max;
87 llvm::transform(argRanges, std::back_inserter(truncated),
97 LLVM_DEBUG(llvm::dbgs() <<
"Index handling: 64-bit result = " << sixtyFour
98 <<
" 32-bit = " << thirtyTwo <<
"\n");
99 bool truncEqual =
false;
102 truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo);
105 truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.
smin() &&
106 thirtyTwo.smax() == sixtyFourAsThirtyTwo.
smax());
109 truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.
umin() &&
110 thirtyTwo.umax() == sixtyFourAsThirtyTwo.
umax());
121 unsigned int destWidth) {
122 APInt umin = range.
umin().zext(destWidth);
123 APInt umax = range.
umax().zext(destWidth);
124 APInt smin = range.
smin().sext(destWidth);
125 APInt smax = range.
smax().sext(destWidth);
126 return {umin, umax, smin, smax};
130 unsigned destWidth) {
131 APInt umin = range.
umin().zext(destWidth);
132 APInt umax = range.
umax().zext(destWidth);
137 unsigned destWidth) {
138 APInt smin = range.
smin().sext(destWidth);
139 APInt smax = range.
smax().sext(destWidth);
144 unsigned int destWidth) {
149 bool hasUnsignedRollover =
150 range.
umin().lshr(destWidth) != range.
umax().lshr(destWidth);
152 : range.umin().trunc(destWidth);
153 APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
154 : range.umax().trunc(destWidth);
165 APInt sminHighPart = range.
smin().ashr(destWidth - 1);
166 APInt smaxHighPart = range.
smax().ashr(destWidth - 1);
167 bool hasSignedOverflow =
168 (sminHighPart != smaxHighPart) &&
169 !(sminHighPart.isAllOnes() &&
170 (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
171 !(sminHighPart.isZero() && smaxHighPart.isZero());
172 APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
173 : range.smin().trunc(destWidth);
174 APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
175 : range.smax().trunc(destWidth);
176 return {umin, umax, smin, smax};
189 const APInt &b) -> std::optional<APInt> {
190 bool overflowed =
false;
191 APInt result = any(ovfFlags & OverflowFlags::Nuw)
193 : a.uadd_ov(b, overflowed);
194 return overflowed ? std::optional<APInt>() : result;
197 const APInt &b) -> std::optional<APInt> {
198 bool overflowed =
false;
199 APInt result = any(ovfFlags & OverflowFlags::Nsw)
201 : a.sadd_ov(b, overflowed);
202 return overflowed ? std::optional<APInt>() : result;
206 uadd, lhs.
umin(), rhs.umin(), lhs.
umax(), rhs.umax(),
false);
208 sadd, lhs.
smin(), rhs.smin(), lhs.
smax(), rhs.smax(),
true);
222 const APInt &b) -> std::optional<APInt> {
223 bool overflowed =
false;
224 APInt result = any(ovfFlags & OverflowFlags::Nuw)
226 : a.usub_ov(b, overflowed);
227 return overflowed ? std::optional<APInt>() : result;
230 const APInt &b) -> std::optional<APInt> {
231 bool overflowed =
false;
232 APInt result = any(ovfFlags & OverflowFlags::Nsw)
234 : a.ssub_ov(b, overflowed);
235 return overflowed ? std::optional<APInt>() : result;
238 usub, lhs.
umin(), rhs.umax(), lhs.
umax(), rhs.umin(),
false);
240 ssub, lhs.
smin(), rhs.smax(), lhs.
smax(), rhs.smin(),
true);
254 const APInt &b) -> std::optional<APInt> {
255 bool overflowed =
false;
256 APInt result = any(ovfFlags & OverflowFlags::Nuw)
258 : a.umul_ov(b, overflowed);
259 return overflowed ? std::optional<APInt>() : result;
262 const APInt &b) -> std::optional<APInt> {
263 bool overflowed =
false;
264 APInt result = any(ovfFlags & OverflowFlags::Nsw)
266 : a.smul_ov(b, overflowed);
267 return overflowed ? std::optional<APInt>() : result;
286 const APInt &lhs,
const APInt &rhs,
const APInt &result)>;
291 const APInt &lhsMin = lhs.
umin(), &lhsMax = lhs.
umax(), &rhsMin = rhs.
umin(),
292 &rhsMax = rhs.
umax();
293 if (!rhsMin.isZero() && !rhsMax.isZero()) {
294 auto udiv = [&fixup](
const APInt &a,
295 const APInt &b) -> std::optional<APInt> {
296 return fixup(a, b, a.udiv(b));
298 return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
303 if (lhsMin.uge(rhsMax) && !rhsMax.isZero())
304 umin = lhsMin.udiv(rhsMax);
314 [](
const APInt &lhs,
const APInt &rhs,
315 const APInt &result) {
return result; });
322 auto ceilDivUIFix = [](
const APInt &lhs,
const APInt &rhs,
323 const APInt &result) -> std::optional<APInt> {
324 if (!lhs.urem(rhs).isZero()) {
325 bool overflowed =
false;
327 result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
328 return overflowed ? std::optional<APInt>() : corrected;
342 const APInt &lhsMin = lhs.
smin(), &lhsMax = lhs.
smax(), &rhsMin = rhs.
smin(),
343 &rhsMax = rhs.
smax();
344 bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
347 auto sdiv = [&fixup](
const APInt &a,
348 const APInt &b) -> std::optional<APInt> {
349 bool overflowed =
false;
350 APInt result = a.sdiv_ov(b, overflowed);
351 return overflowed ? std::optional<APInt>() : fixup(a, b, result);
353 return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
362 [](
const APInt &lhs,
const APInt &rhs,
363 const APInt &result) {
return result; });
370 auto ceilDivSIFix = [](
const APInt &lhs,
const APInt &rhs,
371 const APInt &result) -> std::optional<APInt> {
372 if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
373 bool overflowed =
false;
375 result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
376 return overflowed ? std::optional<APInt>() : corrected;
384 if (lhs.isMinSignedValue() && rhs.sgt(1)) {
390 if (lhs.
smin().isMinSignedValue() && lhs.
smax().sgt(lhs.
smin())) {
404 auto floorDivSIFix = [](
const APInt &lhs,
const APInt &rhs,
405 const APInt &result) -> std::optional<APInt> {
406 if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
407 bool overflowed =
false;
409 result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
410 return overflowed ? std::optional<APInt>() : corrected;
424 const APInt &lhsMin = lhs.
smin(), &lhsMax = lhs.
smax(), &rhsMin = rhs.smin(),
425 &rhsMax = rhs.smax();
427 unsigned width = rhsMax.getBitWidth();
428 APInt smin = APInt::getSignedMinValue(width);
429 APInt smax = APInt::getSignedMaxValue(width);
431 bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
433 APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
434 bool canNegativeDividend = lhsMin.isNegative();
435 bool canPositiveDividend = lhsMax.isStrictlyPositive();
437 APInt maxPositiveResult = maxDivisor - 1;
438 APInt minNegativeResult = -maxPositiveResult;
439 smin = canNegativeDividend ? minNegativeResult : zero;
440 smax = canPositiveDividend ? maxPositiveResult : zero;
442 if (rhsMin == rhsMax) {
443 if ((lhsMax - lhsMin).
ult(maxDivisor)) {
444 APInt minRem = lhsMin.srem(maxDivisor);
445 APInt maxRem = lhsMax.srem(maxDivisor);
446 if (minRem.sle(maxRem)) {
463 const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
465 unsigned width = rhsMin.getBitWidth();
468 APInt umax = llvm::APIntOps::umin((rhsMax - 1), lhs.
umax());
470 if (!rhsMin.isZero()) {
472 if (rhsMin == rhsMax) {
473 const APInt &lhsMin = lhs.
umin(), &lhsMax = lhs.
umax();
474 if ((lhsMax - lhsMin).
ult(rhsMax)) {
475 APInt minRem = lhsMin.urem(rhsMax);
476 APInt maxRem = lhsMax.urem(rhsMax);
477 if (minRem.ule(maxRem)) {
495 const APInt &smin = lhs.
smin().sgt(rhs.smin()) ? lhs.
smin() : rhs.smin();
496 const APInt &smax = lhs.
smax().sgt(rhs.smax()) ? lhs.
smax() : rhs.smax();
504 const APInt &umin = lhs.
umin().ugt(rhs.umin()) ? lhs.
umin() : rhs.umin();
505 const APInt &umax = lhs.
umax().ugt(rhs.umax()) ? lhs.
umax() : rhs.umax();
513 const APInt &smin = lhs.
smin().slt(rhs.smin()) ? lhs.
smin() : rhs.smin();
514 const APInt &smax = lhs.
smax().slt(rhs.smax()) ? lhs.
smax() : rhs.smax();
522 const APInt &umin = lhs.
umin().ult(rhs.umin()) ? lhs.
umin() : rhs.umin();
523 const APInt &umax = lhs.
umax().ult(rhs.umax()) ? lhs.
umax() : rhs.umax();
535 static std::tuple<APInt, APInt>
537 APInt leftVal = bound.
umin(), rightVal = bound.
umax();
538 unsigned bitwidth = leftVal.getBitWidth();
539 unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
540 leftVal.clearLowBits(differingBits);
541 rightVal.setLowBits(differingBits);
542 return std::make_tuple(std::move(leftVal), std::move(rightVal));
549 auto andi = [](
const APInt &a,
const APInt &b) -> std::optional<APInt> {
552 return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
560 auto ori = [](
const APInt &a,
const APInt &b) -> std::optional<APInt> {
563 return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
570 APInt leftVal = bound.
umin(), rightVal = bound.
umax();
571 unsigned bitwidth = leftVal.getBitWidth();
572 unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
573 return APInt::getLowBitsSet(bitwidth, differingBits);
582 APInt res = lhs.
umin() ^ rhs.umin();
583 APInt
min = res & ~mask;
584 APInt
max = res | mask;
596 const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax();
601 const APInt &r) -> std::optional<APInt> {
602 bool overflowed =
false;
603 APInt result = any(ovfFlags & OverflowFlags::Nuw)
605 : l.ushl_ov(r, overflowed);
606 return overflowed ? std::optional<APInt>() : result;
609 const APInt &r) -> std::optional<APInt> {
610 bool overflowed =
false;
611 APInt result = any(ovfFlags & OverflowFlags::Nsw)
613 : l.sshl_ov(r, overflowed);
614 return overflowed ? std::optional<APInt>() : result;
630 auto ashr = [](
const APInt &l,
const APInt &r) -> std::optional<APInt> {
631 return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
642 auto lshr = [](
const APInt &l,
const APInt &r) -> std::optional<APInt> {
643 return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
676 llvm_unreachable(
"unknown cmp predicate value");
702 return lhsConst && rhsConst && lhsConst == rhsConst;
739 APInt typeMax = APInt::getSignedMaxValue(width);
741 auto shapedTy = cast<ShapedType>(op.getShapedValue().getType());
742 if (!shapedTy.hasRank())
745 int64_t rank = shapedTy.getRank();
747 int64_t maxDim = rank - 1;
754 std::optional<ConstantIntRanges> result;
756 if (!result.has_value())
759 result = result->rangeUnion(thisResult);
761 for (int64_t i = minDim; i <= maxDim; ++i) {
762 int64_t length = shapedTy.getDimSize(i);
764 if (ShapedType::isDynamic(length))
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs, const ConstantIntRanges &rhs, DivisionFixupFn fixup)
static bool isStaticallyTrue(intrange::CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred)
static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef< APInt > lhs, ArrayRef< APInt > rhs, bool isSigned)
Compute the minimum and maximum of (op(l, r) for l in lhs for r in rhs), ignoring unbounded values.
static std::tuple< APInt, APInt > widenBitwiseBounds(const ConstantIntRanges &bound)
"Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111,...
static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, const ConstantIntRanges &rhs, DivisionFixupFn fixup)
std::function< std::optional< APInt >(const APInt &, const APInt &)> ConstArithStdFn
static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, const APInt &minRight, const APInt &maxLeft, const APInt &maxRight, bool isSigned)
Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible. If either computation overflows,...
static APInt getVaryingBitsMask(const ConstantIntRanges &bound)
Get bitmask of all bits which can change while iterating in [bound.umin(), bound.umax()].
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges maxRange(unsigned bitwidth)
Create a ConstantIntRanges with the maximum bounds for the width bitwidth, that is - [0,...
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create a ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
static ConstantIntRanges range(const APInt &min, const APInt &max, bool isSigned)
Create a ConstantIntRanges whose minimum is min and maximum is max with isSigned specifying if the mi...
ConstantIntRanges intersection(const ConstantIntRanges &other) const
Returns the intersection (computed separately for signed and unsigned bounds) of this range and other...
static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax)
Create a ConstantIntRanges with the signed minimum and maximum equal to smin and smax,...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const
Returns the union (computed separately for signed and unsigned bounds) of this range and other.
This lattice value represents the integer range of an SSA value.
const ConstantIntRanges & getValue() const
Get the known integer value range.
bool isUninitialized() const
Whether the range is uninitialized.
ConstantIntRanges inferAnd(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShl(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferIndexOp(const InferRangeFn &inferFn, ArrayRef< ConstantIntRanges > argRanges, CmpMode mode)
Compute inferFn on ranges, whose size should be the index storage bitwidth.
ConstantIntRanges inferShrS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extSIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the signed values in range to sign-extend it to destWidth.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
ConstantIntRanges inferMinS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMaxU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemS(ArrayRef< ConstantIntRanges > argRanges)
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
ConstantIntRanges inferOr(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth)
Independently zero-extend the unsigned values and sign-extend the signed values in range to destWidth...
ConstantIntRanges inferSub(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges truncRange(const ConstantIntRanges &range, unsigned destWidth)
Truncate range to destWidth bits, taking care to handle cases such as the truncation of [255,...
ConstantIntRanges inferDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMinU(ArrayRef< ConstantIntRanges > argRanges)
std::function< ConstantIntRanges(ArrayRef< ConstantIntRanges >)> InferRangeFn
Function that performs inference on an array of ConstantIntRanges, abstracted away here to permit wri...
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferXor(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShrU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemU(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMinWidth
ConstantIntRanges inferFloorDivS(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMaxWidth
ConstantIntRanges inferCeilDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extUIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the unsigned values in range to zero-extend it to destWidth.
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
ConstantIntRanges inferMaxS(ArrayRef< ConstantIntRanges > argRanges)
Include the generated interface declarations.