MLIR 22.0.0git
InferIntRangeCommon.cpp
Go to the documentation of this file.
1//===- InferIntRangeCommon.cpp - Inference for common ops ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains implementations of range inference for operations that are
10// common to both the `arith` and `index` dialects to facilitate reuse.
11//
12//===----------------------------------------------------------------------===//
13
15
18
19#include "llvm/ADT/ArrayRef.h"
20#include "llvm/ADT/STLExtras.h"
21
22#include "llvm/Support/Debug.h"
23
24#include <iterator>
25#include <optional>
26
27using namespace mlir;
28
29#define DEBUG_TYPE "int-range-analysis"
30
31//===----------------------------------------------------------------------===//
32// General utilities
33//===----------------------------------------------------------------------===//
34
35/// Function that evaluates the result of doing something on arithmetic
36/// constants and returns std::nullopt on overflow.
38 function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
40 std::function<std::optional<APInt>(const APInt &, const APInt &)>;
41
42/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
43/// If either computation overflows, make the result unbounded.
44static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
45 const APInt &minRight,
46 const APInt &maxLeft,
47 const APInt &maxRight, bool isSigned) {
48 std::optional<APInt> maybeMin = op(minLeft, minRight);
49 std::optional<APInt> maybeMax = op(maxLeft, maxRight);
50 if (maybeMin && maybeMax)
51 return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned);
52 return ConstantIntRanges::maxRange(minLeft.getBitWidth());
53}
54
55/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
56/// ignoring unbounded values. Returns the maximal range if `op` overflows.
58 ArrayRef<APInt> rhs, bool isSigned) {
59 unsigned width = lhs[0].getBitWidth();
60 APInt min =
61 isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
62 APInt max =
63 isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
64 for (const APInt &left : lhs) {
65 for (const APInt &right : rhs) {
66 std::optional<APInt> maybeThisResult = op(left, right);
67 if (!maybeThisResult)
68 return ConstantIntRanges::maxRange(width);
69 APInt result = std::move(*maybeThisResult);
70 min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min;
71 max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max;
72 }
73 }
74 return ConstantIntRanges::range(min, max, isSigned);
75}
76
77//===----------------------------------------------------------------------===//
78// Ext, trunc, index op handling
79//===----------------------------------------------------------------------===//
80
84 intrange::CmpMode mode) {
85 ConstantIntRanges sixtyFour = inferFn(argRanges);
87 llvm::transform(argRanges, std::back_inserter(truncated),
88 [](const ConstantIntRanges &range) {
89 return truncRange(range, /*destWidth=*/indexMinWidth);
90 });
91 ConstantIntRanges thirtyTwo = inferFn(truncated);
92 ConstantIntRanges thirtyTwoAsSixtyFour =
93 extRange(thirtyTwo, /*destWidth=*/indexMaxWidth);
94 ConstantIntRanges sixtyFourAsThirtyTwo =
95 truncRange(sixtyFour, /*destWidth=*/indexMinWidth);
96
97 LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour
98 << " 32-bit = " << thirtyTwo << "\n");
99 bool truncEqual = false;
100 switch (mode) {
102 truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo);
103 break;
105 truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() &&
106 thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax());
107 break;
109 truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() &&
110 thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax());
111 break;
112 }
113 if (truncEqual)
114 // Returning the 64-bit result preserves more information.
115 return sixtyFour;
116 ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour);
117 return merged;
118}
119
121 unsigned int destWidth) {
122 APInt umin = range.umin().zext(destWidth);
123 APInt umax = range.umax().zext(destWidth);
124 APInt smin = range.smin().sext(destWidth);
125 APInt smax = range.smax().sext(destWidth);
126 return {umin, umax, smin, smax};
127}
128
130 unsigned destWidth) {
131 APInt umin = range.umin().zext(destWidth);
132 APInt umax = range.umax().zext(destWidth);
133 return ConstantIntRanges::fromUnsigned(umin, umax);
134}
135
137 unsigned destWidth) {
138 APInt smin = range.smin().sext(destWidth);
139 APInt smax = range.smax().sext(destWidth);
140 return ConstantIntRanges::fromSigned(smin, smax);
141}
142
144 unsigned int destWidth) {
145 // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb],
146 // the range of the resulting value is not contiguous and includes 0.
147 // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2],
148 // but you can't truncate [255, 257] similarly.
149 bool hasUnsignedRollover =
150 range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
151 APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth)
152 : range.umin().trunc(destWidth);
153 APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
154 : range.umax().trunc(destWidth);
155
156 // Signed post-truncation rollover will not occur when either:
157 // - The high parts of the min and max, plus the sign bit, are the same
158 // - The high halves + sign bit of the min and max are either all 1s or all 0s
159 // and you won't create a [positive, negative] range by truncating.
160 // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8
161 // but not [255, 257]_i16 to a range of i8s. You can also truncate
162 // [-258, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16.
163 // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e)
164 // will truncate to 0x7e, which is greater than 0
165 APInt sminHighPart = range.smin().ashr(destWidth - 1);
166 APInt smaxHighPart = range.smax().ashr(destWidth - 1);
167 bool hasSignedOverflow =
168 (sminHighPart != smaxHighPart) &&
169 !(sminHighPart.isAllOnes() &&
170 (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
171 !(sminHighPart.isZero() && smaxHighPart.isZero());
172 APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
173 : range.smin().trunc(destWidth);
174 APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
175 : range.smax().trunc(destWidth);
176 return {umin, umax, smin, smax};
177}
178
179//===----------------------------------------------------------------------===//
180// Addition
181//===----------------------------------------------------------------------===//
182
185 OverflowFlags ovfFlags) {
186 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
187
188 ConstArithStdFn uadd = [=](const APInt &a,
189 const APInt &b) -> std::optional<APInt> {
190 bool overflowed = false;
191 APInt result = any(ovfFlags & OverflowFlags::Nuw)
192 ? a.uadd_sat(b)
193 : a.uadd_ov(b, overflowed);
194 return overflowed ? std::optional<APInt>() : result;
195 };
196 ConstArithStdFn sadd = [=](const APInt &a,
197 const APInt &b) -> std::optional<APInt> {
198 bool overflowed = false;
199 APInt result = any(ovfFlags & OverflowFlags::Nsw)
200 ? a.sadd_sat(b)
201 : a.sadd_ov(b, overflowed);
202 return overflowed ? std::optional<APInt>() : result;
203 };
204
206 uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
208 sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
209 return urange.intersection(srange);
210}
211
212//===----------------------------------------------------------------------===//
213// Subtraction
214//===----------------------------------------------------------------------===//
215
218 OverflowFlags ovfFlags) {
219 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
220
221 ConstArithStdFn usub = [=](const APInt &a,
222 const APInt &b) -> std::optional<APInt> {
223 bool overflowed = false;
224 APInt result = any(ovfFlags & OverflowFlags::Nuw)
225 ? a.usub_sat(b)
226 : a.usub_ov(b, overflowed);
227 return overflowed ? std::optional<APInt>() : result;
228 };
229 ConstArithStdFn ssub = [=](const APInt &a,
230 const APInt &b) -> std::optional<APInt> {
231 bool overflowed = false;
232 APInt result = any(ovfFlags & OverflowFlags::Nsw)
233 ? a.ssub_sat(b)
234 : a.ssub_ov(b, overflowed);
235 return overflowed ? std::optional<APInt>() : result;
236 };
238 usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
240 ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
241 return urange.intersection(srange);
242}
243
244//===----------------------------------------------------------------------===//
245// Multiplication
246//===----------------------------------------------------------------------===//
247
250 OverflowFlags ovfFlags) {
251 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
252
253 ConstArithStdFn umul = [=](const APInt &a,
254 const APInt &b) -> std::optional<APInt> {
255 bool overflowed = false;
256 APInt result = any(ovfFlags & OverflowFlags::Nuw)
257 ? a.umul_sat(b)
258 : a.umul_ov(b, overflowed);
259 return overflowed ? std::optional<APInt>() : result;
260 };
261 ConstArithStdFn smul = [=](const APInt &a,
262 const APInt &b) -> std::optional<APInt> {
263 bool overflowed = false;
264 APInt result = any(ovfFlags & OverflowFlags::Nsw)
265 ? a.smul_sat(b)
266 : a.smul_ov(b, overflowed);
267 return overflowed ? std::optional<APInt>() : result;
268 };
269
270 ConstantIntRanges urange =
271 minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
272 /*isSigned=*/false);
273 ConstantIntRanges srange =
274 minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
275 /*isSigned=*/true);
276 return urange.intersection(srange);
277}
278
279//===----------------------------------------------------------------------===//
280// DivU, CeilDivU (Unsigned division)
281//===----------------------------------------------------------------------===//
282
283/// Fix up division results (ex. for ceiling and floor), returning an APInt
284/// if there has been no overflow
286 const APInt &lhs, const APInt &rhs, const APInt &result)>;
287
289 const ConstantIntRanges &rhs,
290 DivisionFixupFn fixup) {
291 const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
292 &rhsMax = rhs.umax();
293 if (!rhsMin.isZero() && !rhsMax.isZero()) {
294 auto udiv = [&fixup](const APInt &a,
295 const APInt &b) -> std::optional<APInt> {
296 return fixup(a, b, a.udiv(b));
297 };
298 return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
299 /*isSigned=*/false);
300 }
301
302 APInt umin = APInt::getZero(rhsMin.getBitWidth());
303 if (lhsMin.uge(rhsMax) && !rhsMax.isZero())
304 umin = lhsMin.udiv(rhsMax);
305
306 // X u/ Y u<= X.
307 const APInt &umax = lhsMax;
308 return ConstantIntRanges::fromUnsigned(umin, umax);
309}
310
313 return inferDivURange(argRanges[0], argRanges[1],
314 [](const APInt &lhs, const APInt &rhs,
315 const APInt &result) { return result; });
316}
317
320 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
321
322 auto ceilDivUIFix = [](const APInt &lhs, const APInt &rhs,
323 const APInt &result) -> std::optional<APInt> {
324 if (!lhs.urem(rhs).isZero()) {
325 bool overflowed = false;
326 APInt corrected =
327 result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
328 return overflowed ? std::optional<APInt>() : corrected;
329 }
330 return result;
331 };
332 return inferDivURange(lhs, rhs, ceilDivUIFix);
333}
334
335//===----------------------------------------------------------------------===//
336// DivS, CeilDivS, FloorDivS (Signed division)
337//===----------------------------------------------------------------------===//
338
340 const ConstantIntRanges &rhs,
341 DivisionFixupFn fixup) {
342 const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
343 &rhsMax = rhs.smax();
344 bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
345
346 if (canDivide) {
347 auto sdiv = [&fixup](const APInt &a,
348 const APInt &b) -> std::optional<APInt> {
349 bool overflowed = false;
350 APInt result = a.sdiv_ov(b, overflowed);
351 return overflowed ? std::optional<APInt>() : fixup(a, b, result);
352 };
353 return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
354 /*isSigned=*/true);
355 }
356 return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
357}
358
361 return inferDivSRange(argRanges[0], argRanges[1],
362 [](const APInt &lhs, const APInt &rhs,
363 const APInt &result) { return result; });
364}
365
368 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
369
370 auto ceilDivSIFix = [](const APInt &lhs, const APInt &rhs,
371 const APInt &result) -> std::optional<APInt> {
372 if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
373 bool overflowed = false;
374 APInt corrected =
375 result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
376 return overflowed ? std::optional<APInt>() : corrected;
377 }
378 // Special case where the usual implementation of ceilDiv causes
379 // INT_MIN / [positive number] to be positive. This doesn't match the
380 // definition of signed ceiling division mathematically, but it prevents
381 // inconsistent constant-folding results. This arises because (-int_min) is
382 // still negative, so -(-int_min / b) is -(int_min / b), which is
383 // positive. See #115293.
384 if (lhs.isMinSignedValue() && rhs.sgt(1)) {
385 return -result;
386 }
387 return result;
388 };
390 if (lhs.smin().isMinSignedValue() && lhs.smax().sgt(lhs.smin())) {
391 // If lhs range includes INT_MIN and lhs is not a single value, we can
392 // suddenly wrap to positive val, skipping entire negative range, add
393 // [INT_MIN + 1, smax()] range to the result to handle this.
394 auto newLhs = ConstantIntRanges::fromSigned(lhs.smin() + 1, lhs.smax());
395 result = result.rangeUnion(inferDivSRange(newLhs, rhs, ceilDivSIFix));
396 }
397 return result;
398}
399
402 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
403
404 auto floorDivSIFix = [](const APInt &lhs, const APInt &rhs,
405 const APInt &result) -> std::optional<APInt> {
406 if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
407 bool overflowed = false;
408 APInt corrected =
409 result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
410 return overflowed ? std::optional<APInt>() : corrected;
411 }
412 return result;
413 };
414 return inferDivSRange(lhs, rhs, floorDivSIFix);
415}
416
417//===----------------------------------------------------------------------===//
418// Signed remainder (RemS)
419//===----------------------------------------------------------------------===//
420
423 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
424 const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
425 &rhsMax = rhs.smax();
426
427 unsigned width = rhsMax.getBitWidth();
428 APInt smin = APInt::getSignedMinValue(width);
429 APInt smax = APInt::getSignedMaxValue(width);
430 // No bounds if zero could be a divisor.
431 bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
432 if (canBound) {
433 APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
434 bool canNegativeDividend = lhsMin.isNegative();
435 bool canPositiveDividend = lhsMax.isStrictlyPositive();
436 APInt zero = APInt::getZero(maxDivisor.getBitWidth());
437 APInt maxPositiveResult = maxDivisor - 1;
438 APInt minNegativeResult = -maxPositiveResult;
439 smin = canNegativeDividend ? minNegativeResult : zero;
440 smax = canPositiveDividend ? maxPositiveResult : zero;
441 // Special case: sweeping out a contiguous range in N/[modulus].
442 if (rhsMin == rhsMax) {
443 if ((lhsMax - lhsMin).ult(maxDivisor)) {
444 APInt minRem = lhsMin.srem(maxDivisor);
445 APInt maxRem = lhsMax.srem(maxDivisor);
446 if (minRem.sle(maxRem)) {
447 smin = minRem;
448 smax = maxRem;
449 }
450 }
451 }
452 }
453 return ConstantIntRanges::fromSigned(smin, smax);
454}
455
456//===----------------------------------------------------------------------===//
457// Unsigned remainder (RemU)
458//===----------------------------------------------------------------------===//
459
462 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
463 const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
464
465 unsigned width = rhsMin.getBitWidth();
466 APInt umin = APInt::getZero(width);
467 // Remainder can't be larger than either of its arguments.
468 APInt umax = llvm::APIntOps::umin((rhsMax - 1), lhs.umax());
469
470 if (!rhsMin.isZero()) {
471 // Special case: sweeping out a contiguous range in N/[modulus]
472 if (rhsMin == rhsMax) {
473 const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
474 if ((lhsMax - lhsMin).ult(rhsMax)) {
475 APInt minRem = lhsMin.urem(rhsMax);
476 APInt maxRem = lhsMax.urem(rhsMax);
477 if (minRem.ule(maxRem)) {
478 umin = minRem;
479 umax = maxRem;
480 }
481 }
482 }
483 }
484 return ConstantIntRanges::fromUnsigned(umin, umax);
485}
486
487//===----------------------------------------------------------------------===//
488// Max and min (MaxS, MaxU, MinS, MinU)
489//===----------------------------------------------------------------------===//
490
493 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
494
495 const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
496 const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
497 return ConstantIntRanges::fromSigned(smin, smax);
498}
499
502 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
503
504 const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
505 const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
506 return ConstantIntRanges::fromUnsigned(umin, umax);
507}
508
511 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
512
513 const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
514 const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
515 return ConstantIntRanges::fromSigned(smin, smax);
516}
517
520 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
521
522 const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
523 const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
524 return ConstantIntRanges::fromUnsigned(umin, umax);
525}
526
527//===----------------------------------------------------------------------===//
528// Bitwise operators (And, Or, Xor)
529//===----------------------------------------------------------------------===//
530
531/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
532/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
533/// that both bounds have in common. This gives us a conservative approximation
534/// for what values can be passed to bitwise operations.
535static std::tuple<APInt, APInt>
537 APInt leftVal = bound.umin(), rightVal = bound.umax();
538 unsigned bitwidth = leftVal.getBitWidth();
539 unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
540 leftVal.clearLowBits(differingBits);
541 rightVal.setLowBits(differingBits);
542 return std::make_tuple(std::move(leftVal), std::move(rightVal));
543}
544
547 auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
548 auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
549 auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
550 return a & b;
551 };
552 return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
553 /*isSigned=*/false);
554}
555
558 auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
559 auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
560 auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
561 return a | b;
562 };
563 return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
564 /*isSigned=*/false);
565}
566
567/// Get bitmask of all bits which can change while iterating in
568/// [bound.umin(), bound.umax()].
569static APInt getVaryingBitsMask(const ConstantIntRanges &bound) {
570 APInt leftVal = bound.umin(), rightVal = bound.umax();
571 unsigned bitwidth = leftVal.getBitWidth();
572 unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
573 return APInt::getLowBitsSet(bitwidth, differingBits);
574}
575
578 // Construct mask of varying bits for both ranges, xor values and then replace
579 // masked bits with 0s and 1s to get min and max values respectively.
580 ConstantIntRanges lhs = argRanges[0], rhs = argRanges[1];
582 APInt res = lhs.umin() ^ rhs.umin();
583 APInt min = res & ~mask;
584 APInt max = res | mask;
586}
587
588//===----------------------------------------------------------------------===//
589// Shifts (Shl, ShrS, ShrU)
590//===----------------------------------------------------------------------===//
591
594 OverflowFlags ovfFlags) {
595 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
596 const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax();
597
598 // The signed/unsigned overflow behavior of shl by `rhs` matches a mul with
599 // 2^rhs.
600 ConstArithStdFn ushl = [=](const APInt &l,
601 const APInt &r) -> std::optional<APInt> {
602 bool overflowed = false;
603 APInt result = any(ovfFlags & OverflowFlags::Nuw)
604 ? l.ushl_sat(r)
605 : l.ushl_ov(r, overflowed);
606 return overflowed ? std::optional<APInt>() : result;
607 };
608 ConstArithStdFn sshl = [=](const APInt &l,
609 const APInt &r) -> std::optional<APInt> {
610 bool overflowed = false;
611 APInt result = any(ovfFlags & OverflowFlags::Nsw)
612 ? l.sshl_sat(r)
613 : l.sshl_ov(r, overflowed);
614 return overflowed ? std::optional<APInt>() : result;
615 };
616
617 ConstantIntRanges urange =
618 minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhsUMin, rhsUMax},
619 /*isSigned=*/false);
620 ConstantIntRanges srange =
621 minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax},
622 /*isSigned=*/true);
623 return urange.intersection(srange);
624}
625
628 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
629
630 auto ashr = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
631 return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
632 };
633
634 return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
635 /*isSigned=*/true);
636}
637
640 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
641
642 auto lshr = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
643 return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
644 };
645 return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
646 /*isSigned=*/false);
647}
648
649//===----------------------------------------------------------------------===//
650// Comparisons (Cmp)
651//===----------------------------------------------------------------------===//
652
678
680 const ConstantIntRanges &lhs,
681 const ConstantIntRanges &rhs) {
682 switch (pred) {
684 return lhs.smax().sle(rhs.smin());
686 return lhs.smax().slt(rhs.smin());
688 return lhs.umax().ule(rhs.umin());
690 return lhs.umax().ult(rhs.umin());
692 return lhs.smin().sge(rhs.smax());
694 return lhs.smin().sgt(rhs.smax());
696 return lhs.umin().uge(rhs.umax());
698 return lhs.umin().ugt(rhs.umax());
700 std::optional<APInt> lhsConst = lhs.getConstantValue();
701 std::optional<APInt> rhsConst = rhs.getConstantValue();
702 return lhsConst && rhsConst && lhsConst == rhsConst;
703 }
705 // While equality requires that there is an interpretation of the preceding
706 // computations that produces equal constants, whether that be signed or
707 // unsigned, statically determining inequality requires that neither
708 // interpretation produce potentially overlapping ranges.
713 return sne && une;
714 }
715 }
716 return false;
717}
718
720 const ConstantIntRanges &lhs,
721 const ConstantIntRanges &rhs) {
722 if (isStaticallyTrue(pred, lhs, rhs))
723 return true;
725 return false;
726 return std::nullopt;
727}
728
729//===----------------------------------------------------------------------===//
730// Shaped type dimension accessors / ShapedDimOpInterface
731//===----------------------------------------------------------------------===//
732
735 const IntegerValueRange &maybeDim) {
736 unsigned width =
737 ConstantIntRanges::getStorageBitwidth(op->getResult(0).getType());
738 APInt zero = APInt::getZero(width);
739 APInt typeMax = APInt::getSignedMaxValue(width);
740
741 auto shapedTy = cast<ShapedType>(op.getShapedValue().getType());
742 if (!shapedTy.hasRank())
743 return ConstantIntRanges::fromSigned(zero, typeMax);
744
745 int64_t rank = shapedTy.getRank();
746 int64_t minDim = 0;
747 int64_t maxDim = rank - 1;
748 if (!maybeDim.isUninitialized()) {
749 const ConstantIntRanges &dim = maybeDim.getValue();
750 minDim = std::max(minDim, dim.smin().getSExtValue());
751 maxDim = std::min(maxDim, dim.smax().getSExtValue());
752 }
753
754 std::optional<ConstantIntRanges> result;
755 auto joinResult = [&](const ConstantIntRanges &thisResult) {
756 if (!result.has_value())
757 result = thisResult;
758 else
759 result = result->rangeUnion(thisResult);
760 };
761 for (int64_t i = minDim; i <= maxDim; ++i) {
762 int64_t length = shapedTy.getDimSize(i);
763
764 if (ShapedType::isDynamic(length))
765 joinResult(ConstantIntRanges::fromSigned(zero, typeMax));
766 else
767 joinResult(ConstantIntRanges::constant(APInt(width, length)));
768 }
769 return result.value_or(ConstantIntRanges::fromSigned(zero, typeMax));
770}
lhs
static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs, const ConstantIntRanges &rhs, DivisionFixupFn fixup)
static bool isStaticallyTrue(intrange::CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred)
static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef< APInt > lhs, ArrayRef< APInt > rhs, bool isSigned)
Compute the minimum and maximum of (op(l, r) for l in lhs for r in rhs), ignoring unbounded values.
std::function< std::optional< APInt >(const APInt &, const APInt &)> ConstArithStdFn
static std::tuple< APInt, APInt > widenBitwiseBounds(const ConstantIntRanges &bound)
"Widen" bounds - if 0bvvvvv?
static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, const ConstantIntRanges &rhs, DivisionFixupFn fixup)
function_ref< std::optional< APInt >(const APInt &, const APInt &)> ConstArithFn
Function that evaluates the result of doing something on arithmetic constants and returns std::nullop...
static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, const APInt &minRight, const APInt &maxLeft, const APInt &maxRight, bool isSigned)
Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible, If either computation overflows,...
static APInt getVaryingBitsMask(const ConstantIntRanges &bound)
Get bitmask of all bits which can change while iterating in [bound.umin(), bound.umax()].
function_ref< std::optional< APInt >( const APInt &lhs, const APInt &rhs, const APInt &result)> DivisionFixupFn
Fix up division results (ex.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges maxRange(unsigned bitwidth)
Create a ConstantIntRanges with the maximum bounds for the width bitwidth, that is - [0,...
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create an ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
static ConstantIntRanges range(const APInt &min, const APInt &max, bool isSigned)
Create a ConstantIntRanges whose minimum is min and maximum is max with isSigned specifying if the mi...
ConstantIntRanges intersection(const ConstantIntRanges &other) const
Returns the intersection (computed separately for signed and unsigned bounds) of this range and other...
static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax)
Create an ConstantIntRanges with the signed minimum and maximum equal to smin and smax,...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const
Returns the union (computed separately for signed and unsigned bounds) of this range and other.
This lattice value represents the integer range of an SSA value.
const ConstantIntRanges & getValue() const
Get the known integer value range.
bool isUninitialized() const
Whether the range is uninitialized.
ConstantIntRanges inferAnd(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShl(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferIndexOp(const InferRangeFn &inferFn, ArrayRef< ConstantIntRanges > argRanges, CmpMode mode)
Compute inferFn on ranges, whose size should be the index storage bitwidth.
ConstantIntRanges inferShrS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extSIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the signed values in range to sign-extend it to destWidth.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
ConstantIntRanges inferMinS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMaxU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemS(ArrayRef< ConstantIntRanges > argRanges)
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
ConstantIntRanges inferOr(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth)
Independently zero-extend the unsigned values and sign-extend the signed values in range to destWidth...
ConstantIntRanges inferSub(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
std::function< ConstantIntRanges(ArrayRef< ConstantIntRanges >)> InferRangeFn
Function that performs inference on an array of ConstantIntRanges, abstracted away here to permit wri...
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges truncRange(const ConstantIntRanges &range, unsigned destWidth)
Truncate range to destWidth bits, taking care to handle cases such as the truncation of [255,...
ConstantIntRanges inferDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMinU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferXor(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShrU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemU(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMinWidth
ConstantIntRanges inferFloorDivS(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMaxWidth
ConstantIntRanges inferCeilDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extUIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the unsigned values in range to zero-extend it to destWidth.
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
ConstantIntRanges inferMaxS(ArrayRef< ConstantIntRanges > argRanges)
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152