MLIR 23.0.0git
InferIntRangeCommon.cpp
Go to the documentation of this file.
1//===- InferIntRangeCommon.cpp - Inference for common ops ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains implementations of range inference for operations that are
10// common to both the `arith` and `index` dialects to facilitate reuse.
11//
12//===----------------------------------------------------------------------===//
13
15
16#include "mlir/IR/AffineExpr.h"
19
20#include "llvm/ADT/ArrayRef.h"
21#include "llvm/ADT/STLExtras.h"
22
23#include "llvm/Support/Debug.h"
24
25#include <iterator>
26#include <optional>
27
28using namespace mlir;
29
30#define DEBUG_TYPE "int-range-analysis"
31
32//===----------------------------------------------------------------------===//
33// General utilities
34//===----------------------------------------------------------------------===//
35
36/// Function that evaluates the result of doing something on arithmetic
37/// constants and returns std::nullopt on overflow.
39 function_ref<std::optional<APInt>(const APInt &, const APInt &)>;
41 std::function<std::optional<APInt>(const APInt &, const APInt &)>;
42
43/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
44/// If either computation overflows, make the result unbounded.
45static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
46 const APInt &minRight,
47 const APInt &maxLeft,
48 const APInt &maxRight, bool isSigned) {
49 std::optional<APInt> maybeMin = op(minLeft, minRight);
50 std::optional<APInt> maybeMax = op(maxLeft, maxRight);
51 if (maybeMin && maybeMax)
52 return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned);
53 return ConstantIntRanges::maxRange(minLeft.getBitWidth());
54}
55
56/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
57/// ignoring unbounded values. Returns the maximal range if `op` overflows.
59 ArrayRef<APInt> rhs, bool isSigned) {
60 unsigned width = lhs[0].getBitWidth();
61 APInt min =
62 isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
63 APInt max =
64 isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
65 for (const APInt &left : lhs) {
66 for (const APInt &right : rhs) {
67 std::optional<APInt> maybeThisResult = op(left, right);
68 if (!maybeThisResult)
69 return ConstantIntRanges::maxRange(width);
70 APInt result = std::move(*maybeThisResult);
71 min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min;
72 max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max;
73 }
74 }
75 return ConstantIntRanges::range(min, max, isSigned);
76}
77
78//===----------------------------------------------------------------------===//
79// Ext, trunc, index op handling
80//===----------------------------------------------------------------------===//
81
85 intrange::CmpMode mode) {
86 ConstantIntRanges sixtyFour = inferFn(argRanges);
88 llvm::transform(argRanges, std::back_inserter(truncated),
89 [](const ConstantIntRanges &range) {
90 return truncRange(range, /*destWidth=*/indexMinWidth);
91 });
92 ConstantIntRanges thirtyTwo = inferFn(truncated);
93 ConstantIntRanges thirtyTwoAsSixtyFour =
94 extRange(thirtyTwo, /*destWidth=*/indexMaxWidth);
95 ConstantIntRanges sixtyFourAsThirtyTwo =
96 truncRange(sixtyFour, /*destWidth=*/indexMinWidth);
97
98 LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour
99 << " 32-bit = " << thirtyTwo << "\n");
100 bool truncEqual = false;
101 switch (mode) {
103 truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo);
104 break;
106 truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() &&
107 thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax());
108 break;
110 truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() &&
111 thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax());
112 break;
113 }
114 if (truncEqual)
115 // Returing the 64-bit result preserves more information.
116 return sixtyFour;
117 ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour);
118 return merged;
119}
120
122 unsigned int destWidth) {
123 APInt umin = range.umin().zext(destWidth);
124 APInt umax = range.umax().zext(destWidth);
125 APInt smin = range.smin().sext(destWidth);
126 APInt smax = range.smax().sext(destWidth);
127 return {umin, umax, smin, smax};
128}
129
131 unsigned destWidth) {
132 APInt umin = range.umin().zext(destWidth);
133 APInt umax = range.umax().zext(destWidth);
134 return ConstantIntRanges::fromUnsigned(umin, umax);
135}
136
138 unsigned destWidth) {
139 APInt smin = range.smin().sext(destWidth);
140 APInt smax = range.smax().sext(destWidth);
141 return ConstantIntRanges::fromSigned(smin, smax);
142}
143
145 unsigned int destWidth) {
146 // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb],
147 // the range of the resulting value is not contiguous ind includes 0.
148 // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2],
149 // but you can't truncate [255, 257] similarly.
150 bool hasUnsignedRollover =
151 range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
152 APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth)
153 : range.umin().trunc(destWidth);
154 APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
155 : range.umax().trunc(destWidth);
156
157 // Signed post-truncation rollover will not occur when either:
158 // - The high parts of the min and max, plus the sign bit, are the same
159 // - The high halves + sign bit of the min and max are either all 1s or all 0s
160 // and you won't create a [positive, negative] range by truncating.
161 // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8
162 // but not [255, 257]_i16 to a range of i8s. You can also truncate
163 // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16.
164 // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e)
165 // will truncate to 0x7e, which is greater than 0
166 APInt sminHighPart = range.smin().ashr(destWidth - 1);
167 APInt smaxHighPart = range.smax().ashr(destWidth - 1);
168 bool hasSignedOverflow =
169 (sminHighPart != smaxHighPart) &&
170 !(sminHighPart.isAllOnes() &&
171 (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
172 !(sminHighPart.isZero() && smaxHighPart.isZero());
173 APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
174 : range.smin().trunc(destWidth);
175 APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
176 : range.smax().trunc(destWidth);
177 return {umin, umax, smin, smax};
178}
179
180//===----------------------------------------------------------------------===//
181// Addition
182//===----------------------------------------------------------------------===//
183
186 OverflowFlags ovfFlags) {
187 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
188
189 ConstArithStdFn uadd = [=](const APInt &a,
190 const APInt &b) -> std::optional<APInt> {
191 bool overflowed = false;
192 APInt result = any(ovfFlags & OverflowFlags::Nuw)
193 ? a.uadd_sat(b)
194 : a.uadd_ov(b, overflowed);
195 return overflowed ? std::optional<APInt>() : result;
196 };
197 ConstArithStdFn sadd = [=](const APInt &a,
198 const APInt &b) -> std::optional<APInt> {
199 bool overflowed = false;
200 APInt result = any(ovfFlags & OverflowFlags::Nsw)
201 ? a.sadd_sat(b)
202 : a.sadd_ov(b, overflowed);
203 return overflowed ? std::optional<APInt>() : result;
204 };
205
207 uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
209 sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
210 return urange.intersection(srange);
211}
212
213//===----------------------------------------------------------------------===//
214// Subtraction
215//===----------------------------------------------------------------------===//
216
219 OverflowFlags ovfFlags) {
220 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
221
222 ConstArithStdFn usub = [=](const APInt &a,
223 const APInt &b) -> std::optional<APInt> {
224 bool overflowed = false;
225 APInt result = any(ovfFlags & OverflowFlags::Nuw)
226 ? a.usub_sat(b)
227 : a.usub_ov(b, overflowed);
228 return overflowed ? std::optional<APInt>() : result;
229 };
230 ConstArithStdFn ssub = [=](const APInt &a,
231 const APInt &b) -> std::optional<APInt> {
232 bool overflowed = false;
233 APInt result = any(ovfFlags & OverflowFlags::Nsw)
234 ? a.ssub_sat(b)
235 : a.ssub_ov(b, overflowed);
236 return overflowed ? std::optional<APInt>() : result;
237 };
239 usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
241 ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
242 return urange.intersection(srange);
243}
244
245//===----------------------------------------------------------------------===//
246// Multiplication
247//===----------------------------------------------------------------------===//
248
251 OverflowFlags ovfFlags) {
252 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
253
254 ConstArithStdFn umul = [=](const APInt &a,
255 const APInt &b) -> std::optional<APInt> {
256 bool overflowed = false;
257 APInt result = any(ovfFlags & OverflowFlags::Nuw)
258 ? a.umul_sat(b)
259 : a.umul_ov(b, overflowed);
260 return overflowed ? std::optional<APInt>() : result;
261 };
262 ConstArithStdFn smul = [=](const APInt &a,
263 const APInt &b) -> std::optional<APInt> {
264 bool overflowed = false;
265 APInt result = any(ovfFlags & OverflowFlags::Nsw)
266 ? a.smul_sat(b)
267 : a.smul_ov(b, overflowed);
268 return overflowed ? std::optional<APInt>() : result;
269 };
270
271 ConstantIntRanges urange =
272 minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
273 /*isSigned=*/false);
274 ConstantIntRanges srange =
275 minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
276 /*isSigned=*/true);
277 return urange.intersection(srange);
278}
279
280//===----------------------------------------------------------------------===//
281// DivU, CeilDivU (Unsigned division)
282//===----------------------------------------------------------------------===//
283
284/// Fix up division results (ex. for ceiling and floor), returning an APInt
285/// if there has been no overflow
287 const APInt &lhs, const APInt &rhs, const APInt &result)>;
288
290 const ConstantIntRanges &rhs,
291 DivisionFixupFn fixup) {
292 const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
293 &rhsMax = rhs.umax();
294 if (!rhsMin.isZero() && !rhsMax.isZero()) {
295 auto udiv = [&fixup](const APInt &a,
296 const APInt &b) -> std::optional<APInt> {
297 return fixup(a, b, a.udiv(b));
298 };
299 return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
300 /*isSigned=*/false);
301 }
302
303 APInt umin = APInt::getZero(rhsMin.getBitWidth());
304 if (lhsMin.uge(rhsMax) && !rhsMax.isZero())
305 umin = lhsMin.udiv(rhsMax);
306
307 // X u/ Y u<= X.
308 const APInt &umax = lhsMax;
309 return ConstantIntRanges::fromUnsigned(umin, umax);
310}
311
314 return inferDivURange(argRanges[0], argRanges[1],
315 [](const APInt &lhs, const APInt &rhs,
316 const APInt &result) { return result; });
317}
318
321 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
322
323 auto ceilDivUIFix = [](const APInt &lhs, const APInt &rhs,
324 const APInt &result) -> std::optional<APInt> {
325 if (!lhs.urem(rhs).isZero()) {
326 bool overflowed = false;
327 APInt corrected =
328 result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
329 return overflowed ? std::optional<APInt>() : corrected;
330 }
331 return result;
332 };
333 return inferDivURange(lhs, rhs, ceilDivUIFix);
334}
335
336//===----------------------------------------------------------------------===//
337// DivS, CeilDivS, FloorDivS (Signed division)
338//===----------------------------------------------------------------------===//
339
341 const ConstantIntRanges &rhs,
342 DivisionFixupFn fixup) {
343 const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
344 &rhsMax = rhs.smax();
345 bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
346
347 if (canDivide) {
348 auto sdiv = [&fixup](const APInt &a,
349 const APInt &b) -> std::optional<APInt> {
350 bool overflowed = false;
351 APInt result = a.sdiv_ov(b, overflowed);
352 return overflowed ? std::optional<APInt>() : fixup(a, b, result);
353 };
354 return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
355 /*isSigned=*/true);
356 }
357 return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
358}
359
362 return inferDivSRange(argRanges[0], argRanges[1],
363 [](const APInt &lhs, const APInt &rhs,
364 const APInt &result) { return result; });
365}
366
369 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
370
371 auto ceilDivSIFix = [](const APInt &lhs, const APInt &rhs,
372 const APInt &result) -> std::optional<APInt> {
373 if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
374 bool overflowed = false;
375 APInt corrected =
376 result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
377 return overflowed ? std::optional<APInt>() : corrected;
378 }
379 // Special case where the usual implementation of ceilDiv causes
380 // INT_MIN / [positive number] to be positive. This doesn't match the
381 // definition of signed ceiling division mathematically, but it prevents
382 // inconsistent constant-folding results. This arises because (-int_min) is
383 // still negative, so -(-int_min / b) is -(int_min / b), which is
384 // positive See #115293.
385 if (lhs.isMinSignedValue() && rhs.sgt(1)) {
386 return -result;
387 }
388 return result;
389 };
391 if (lhs.smin().isMinSignedValue() && lhs.smax().sgt(lhs.smin())) {
392 // If lhs range includes INT_MIN and lhs is not a single value, we can
393 // suddenly wrap to positive val, skipping entire negative range, add
394 // [INT_MIN + 1, smax()] range to the result to handle this.
395 auto newLhs = ConstantIntRanges::fromSigned(lhs.smin() + 1, lhs.smax());
396 result = result.rangeUnion(inferDivSRange(newLhs, rhs, ceilDivSIFix));
397 }
398 return result;
399}
400
403 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
404
405 auto floorDivSIFix = [](const APInt &lhs, const APInt &rhs,
406 const APInt &result) -> std::optional<APInt> {
407 if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
408 bool overflowed = false;
409 APInt corrected =
410 result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
411 return overflowed ? std::optional<APInt>() : corrected;
412 }
413 return result;
414 };
415 return inferDivSRange(lhs, rhs, floorDivSIFix);
416}
417
418//===----------------------------------------------------------------------===//
419// Signed remainder (RemS)
420//===----------------------------------------------------------------------===//
421
424 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
425 const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
426 &rhsMax = rhs.smax();
427
428 unsigned width = rhsMax.getBitWidth();
429 APInt smin = APInt::getSignedMinValue(width);
430 APInt smax = APInt::getSignedMaxValue(width);
431 // No bounds if zero could be a divisor.
432 bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
433 if (canBound) {
434 APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
435 bool canNegativeDividend = lhsMin.isNegative();
436 bool canPositiveDividend = lhsMax.isStrictlyPositive();
437 APInt zero = APInt::getZero(maxDivisor.getBitWidth());
438 APInt maxPositiveResult = maxDivisor - 1;
439 APInt minNegativeResult = -maxPositiveResult;
440 smin = canNegativeDividend ? minNegativeResult : zero;
441 smax = canPositiveDividend ? maxPositiveResult : zero;
442 // Special case: sweeping out a contiguous range in N/[modulus].
443 if (rhsMin == rhsMax) {
444 if ((lhsMax - lhsMin).ult(maxDivisor)) {
445 APInt minRem = lhsMin.srem(maxDivisor);
446 APInt maxRem = lhsMax.srem(maxDivisor);
447 if (minRem.sle(maxRem)) {
448 smin = minRem;
449 smax = maxRem;
450 }
451 }
452 }
453 }
454 return ConstantIntRanges::fromSigned(smin, smax);
455}
456
457//===----------------------------------------------------------------------===//
458// Unsigned remainder (RemU)
459//===----------------------------------------------------------------------===//
460
463 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
464 const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
465
466 unsigned width = rhsMin.getBitWidth();
467 APInt umin = APInt::getZero(width);
468 // Remainder can't be larger than either of its arguments.
469 APInt umax = llvm::APIntOps::umin((rhsMax - 1), lhs.umax());
470
471 if (!rhsMin.isZero()) {
472 // Special case: sweeping out a contiguous range in N/[modulus]
473 if (rhsMin == rhsMax) {
474 const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
475 if ((lhsMax - lhsMin).ult(rhsMax)) {
476 APInt minRem = lhsMin.urem(rhsMax);
477 APInt maxRem = lhsMax.urem(rhsMax);
478 if (minRem.ule(maxRem)) {
479 umin = minRem;
480 umax = maxRem;
481 }
482 }
483 }
484 }
485 return ConstantIntRanges::fromUnsigned(umin, umax);
486}
487
488//===----------------------------------------------------------------------===//
489// Max and min (MaxS, MaxU, MinS, MinU)
490//===----------------------------------------------------------------------===//
491
494 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
495
496 const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
497 const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
498 return ConstantIntRanges::fromSigned(smin, smax);
499}
500
503 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
504
505 const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
506 const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
507 return ConstantIntRanges::fromUnsigned(umin, umax);
508}
509
512 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
513
514 const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
515 const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
516 return ConstantIntRanges::fromSigned(smin, smax);
517}
518
521 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
522
523 const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
524 const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
525 return ConstantIntRanges::fromUnsigned(umin, umax);
526}
527
528//===----------------------------------------------------------------------===//
529// Bitwise operators (And, Or, Xor)
530//===----------------------------------------------------------------------===//
531
532/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
533/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
534/// that both bounds have in common. This gives us a conservative approximation
535/// for what values can be passed to bitwise operations.
536static std::tuple<APInt, APInt>
538 APInt leftVal = bound.umin(), rightVal = bound.umax();
539 unsigned bitwidth = leftVal.getBitWidth();
540 unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
541 leftVal.clearLowBits(differingBits);
542 rightVal.setLowBits(differingBits);
543 return std::make_tuple(std::move(leftVal), std::move(rightVal));
544}
545
548 auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
549 auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
550 auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
551 return a & b;
552 };
553 return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
554 /*isSigned=*/false);
555}
556
559 auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
560 auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
561 auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
562 return a | b;
563 };
564 return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
565 /*isSigned=*/false);
566}
567
568/// Get bitmask of all bits which can change while iterating in
569/// [bound.umin(), bound.umax()].
570static APInt getVaryingBitsMask(const ConstantIntRanges &bound) {
571 APInt leftVal = bound.umin(), rightVal = bound.umax();
572 unsigned bitwidth = leftVal.getBitWidth();
573 unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
574 return APInt::getLowBitsSet(bitwidth, differingBits);
575}
576
579 // Construct mask of varying bits for both ranges, xor values and then replace
580 // masked bits with 0s and 1s to get min and max values respectively.
581 ConstantIntRanges lhs = argRanges[0], rhs = argRanges[1];
583 APInt res = lhs.umin() ^ rhs.umin();
584 APInt min = res & ~mask;
585 APInt max = res | mask;
587}
588
589//===----------------------------------------------------------------------===//
590// Shifts (Shl, ShrS, ShrU)
591//===----------------------------------------------------------------------===//
592
595 OverflowFlags ovfFlags) {
596 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
597 const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax();
598
599 // The signed/unsigned overflow behavior of shl by `rhs` matches a mul with
600 // 2^rhs.
601 ConstArithStdFn ushl = [=](const APInt &l,
602 const APInt &r) -> std::optional<APInt> {
603 bool overflowed = false;
604 APInt result = any(ovfFlags & OverflowFlags::Nuw)
605 ? l.ushl_sat(r)
606 : l.ushl_ov(r, overflowed);
607 return overflowed ? std::optional<APInt>() : result;
608 };
609 ConstArithStdFn sshl = [=](const APInt &l,
610 const APInt &r) -> std::optional<APInt> {
611 bool overflowed = false;
612 APInt result = any(ovfFlags & OverflowFlags::Nsw)
613 ? l.sshl_sat(r)
614 : l.sshl_ov(r, overflowed);
615 return overflowed ? std::optional<APInt>() : result;
616 };
617
618 ConstantIntRanges urange =
619 minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhsUMin, rhsUMax},
620 /*isSigned=*/false);
621 ConstantIntRanges srange =
622 minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax},
623 /*isSigned=*/true);
624 return urange.intersection(srange);
625}
626
629 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
630
631 auto ashr = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
632 return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r);
633 };
634
635 return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
636 /*isSigned=*/true);
637}
638
641 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
642
643 auto lshr = [](const APInt &l, const APInt &r) -> std::optional<APInt> {
644 return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r);
645 };
646 return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
647 /*isSigned=*/false);
648}
649
650//===----------------------------------------------------------------------===//
651// Comparisons (Cmp)
652//===----------------------------------------------------------------------===//
653
679
681 const ConstantIntRanges &lhs,
682 const ConstantIntRanges &rhs) {
683 switch (pred) {
685 return lhs.smax().sle(rhs.smin());
687 return lhs.smax().slt(rhs.smin());
689 return lhs.umax().ule(rhs.umin());
691 return lhs.umax().ult(rhs.umin());
693 return lhs.smin().sge(rhs.smax());
695 return lhs.smin().sgt(rhs.smax());
697 return lhs.umin().uge(rhs.umax());
699 return lhs.umin().ugt(rhs.umax());
701 std::optional<APInt> lhsConst = lhs.getConstantValue();
702 std::optional<APInt> rhsConst = rhs.getConstantValue();
703 return lhsConst && rhsConst && lhsConst == rhsConst;
704 }
706 // While equality requires that there is an interpration of the preceeding
707 // computations that produces equal constants, whether that be signed or
708 // unsigned, statically determining inequality requires that neither
709 // interpretation produce potentially overlapping ranges.
714 return sne && une;
715 }
716 }
717 return false;
718}
719
721 const ConstantIntRanges &lhs,
722 const ConstantIntRanges &rhs) {
723 if (isStaticallyTrue(pred, lhs, rhs))
724 return true;
726 return false;
727 return std::nullopt;
728}
729
730//===----------------------------------------------------------------------===//
731// Shaped type dimension accessors / ShapedDimOpInterface
732//===----------------------------------------------------------------------===//
733
736 const IntegerValueRange &maybeDim) {
737 unsigned width =
738 ConstantIntRanges::getStorageBitwidth(op->getResult(0).getType());
739 APInt zero = APInt::getZero(width);
740 APInt typeMax = APInt::getSignedMaxValue(width);
741
742 auto shapedTy = cast<ShapedType>(op.getShapedValue().getType());
743 if (!shapedTy.hasRank())
744 return ConstantIntRanges::fromSigned(zero, typeMax);
745
746 int64_t rank = shapedTy.getRank();
747 int64_t minDim = 0;
748 int64_t maxDim = rank - 1;
749 if (!maybeDim.isUninitialized()) {
750 const ConstantIntRanges &dim = maybeDim.getValue();
751 minDim = std::max(minDim, dim.smin().getSExtValue());
752 maxDim = std::min(maxDim, dim.smax().getSExtValue());
753 }
754
755 std::optional<ConstantIntRanges> result;
756 auto joinResult = [&](const ConstantIntRanges &thisResult) {
757 if (!result.has_value())
758 result = thisResult;
759 else
760 result = result->rangeUnion(thisResult);
761 };
762 for (int64_t i = minDim; i <= maxDim; ++i) {
763 int64_t length = shapedTy.getDimSize(i);
764
765 if (ShapedType::isDynamic(length))
766 joinResult(ConstantIntRanges::fromSigned(zero, typeMax));
767 else
768 joinResult(ConstantIntRanges::constant(APInt(width, length)));
769 }
770 return result.value_or(ConstantIntRanges::fromSigned(zero, typeMax));
771}
772
773//===----------------------------------------------------------------------===//
774// Affine expression inference
775//===----------------------------------------------------------------------===//
776
778 unsigned width = val.smin().getBitWidth();
779 APInt one(width, 1);
780 APInt clampedUMin = val.umin().ult(one) ? one : val.umin();
781 APInt clampedSMin = val.smin().slt(one) ? one : val.smin();
782 return ConstantIntRanges::fromUnsigned(clampedUMin, val.umax())
783 .intersection(ConstantIntRanges::fromSigned(clampedSMin, val.smax()));
784}
785
789 ArrayRef<ConstantIntRanges> symbolRanges) {
790 switch (expr.getKind()) {
792 auto constExpr = cast<AffineConstantExpr>(expr);
793 APInt value(indexMaxWidth, constExpr.getValue(), /*isSigned=*/true);
794 return ConstantIntRanges::constant(value);
795 }
797 auto dimExpr = cast<AffineDimExpr>(expr);
798 unsigned pos = dimExpr.getPosition();
799 assert(pos < dimRanges.size() && "Dimension index out of bounds");
800 return dimRanges[pos];
801 }
803 auto symbolExpr = cast<AffineSymbolExpr>(expr);
804 unsigned pos = symbolExpr.getPosition();
805 assert(pos < symbolRanges.size() && "Symbol index out of bounds");
806 return symbolRanges[pos];
807 }
808 case AffineExprKind::Add: {
809 auto binExpr = cast<AffineBinaryOpExpr>(expr);
811 inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
813 inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
815 }
816 case AffineExprKind::Mul: {
817 auto binExpr = cast<AffineBinaryOpExpr>(expr);
819 inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
821 inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
823 }
824 case AffineExprKind::Mod: {
825 auto binExpr = cast<AffineBinaryOpExpr>(expr);
827 inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
829 inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
830 // Affine mod is Euclidean modulo: result is always in [0, rhs-1].
831 // This assumes RHS is positive (enforced by affine expr semantics).
832 const APInt &lhsMin = lhs.smin();
833 const APInt &lhsMax = lhs.smax();
834 const APInt &rhsMin = rhs.smin();
835 const APInt &rhsMax = rhs.smax();
836 unsigned width = rhsMin.getBitWidth();
837
838 // Guard against division by zero.
839 if (rhsMax.isZero())
840 return ConstantIntRanges::maxRange(width);
841
842 APInt zero = APInt::getZero(width);
843
844 // For Euclidean mod, result is in [0, max(rhs)-1].
845 APInt umin = zero;
846 APInt umax = rhsMax - 1;
847
848 // Special case: if dividend is already in [0, min(rhs)), result equals
849 // dividend. We use rhsMin to ensure this is safe for all possible divisor
850 // values.
851 if (rhsMin.isStrictlyPositive() && lhsMin.isNonNegative() &&
852 lhsMax.ult(rhsMin)) {
853 umin = lhsMin;
854 umax = lhsMax;
855 }
856 // Special case: sweeping out a contiguous range with constant divisor.
857 // Only applies when dividend is non-negative to ensure result range is
858 // contiguous.
859 else if (rhsMin == rhsMax && lhsMin.isNonNegative() &&
860 (lhsMax - lhsMin).ult(rhsMax)) {
861 // For non-negative dividends, Euclidean mod is same as unsigned
862 // remainder.
863 umin = lhsMin.urem(rhsMax);
864 umax = lhsMax.urem(rhsMax);
865 // Result should be contiguous since we're not wrapping around.
866 assert(umin.ule(umax) &&
867 "Range should be contiguous for non-negative dividend");
868 }
869
870 return ConstantIntRanges::fromUnsigned(umin, umax);
871 }
873 auto binExpr = cast<AffineBinaryOpExpr>(expr);
875 inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
877 inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
878 // Affine floordiv requires strictly positive divisor (> 0).
879 // Clamp divisor lower bound to 1 for tighter range inference.
881 return inferFloorDivS({lhs, clampedRhs});
882 }
884 auto binExpr = cast<AffineBinaryOpExpr>(expr);
886 inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
888 inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
889 // Affine ceildiv requires strictly positive divisor (> 0).
890 // Clamp divisor lower bound to 1 for tighter range inference.
892 return inferCeilDivS({lhs, clampedRhs});
893 }
894 }
895 llvm_unreachable("unknown affine expression kind");
896}
lhs
static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs, const ConstantIntRanges &rhs, DivisionFixupFn fixup)
static bool isStaticallyTrue(intrange::CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred)
static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef< APInt > lhs, ArrayRef< APInt > rhs, bool isSigned)
Compute the minimum and maximum of (op(l, r) for l in lhs for r in rhs), ignoring unbounded values.
std::function< std::optional< APInt >(const APInt &, const APInt &)> ConstArithStdFn
static std::tuple< APInt, APInt > widenBitwiseBounds(const ConstantIntRanges &bound)
"Widen" bounds - if 0bvvvvv?
static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, const ConstantIntRanges &rhs, DivisionFixupFn fixup)
function_ref< std::optional< APInt >(const APInt &, const APInt &)> ConstArithFn
Function that evaluates the result of doing something on arithmetic constants and returns std::nullopt on overflow.
static ConstantIntRanges clampToPositive(const ConstantIntRanges &val)
static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, const APInt &minRight, const APInt &maxLeft, const APInt &maxRight, bool isSigned)
Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible. If either computation overflows, make the result unbounded.
static APInt getVaryingBitsMask(const ConstantIntRanges &bound)
Get bitmask of all bits which can change while iterating in [bound.umin(), bound.umax()].
function_ref< std::optional< APInt >( const APInt &lhs, const APInt &rhs, const APInt &result)> DivisionFixupFn
Fix up division results (ex. for ceiling and floor), returning an APInt if there has been no overflow.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
Definition AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges maxRange(unsigned bitwidth)
Create a ConstantIntRanges with the maximum bounds for the width bitwidth, that is - [0,...
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create a ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
static ConstantIntRanges range(const APInt &min, const APInt &max, bool isSigned)
Create a ConstantIntRanges whose minimum is min and maximum is max with isSigned specifying if the mi...
ConstantIntRanges intersection(const ConstantIntRanges &other) const
Returns the intersection (computed separately for signed and unsigned bounds) of this range and other...
static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax)
Create a ConstantIntRanges with the signed minimum and maximum equal to smin and smax,...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const
Returns the union (computed separately for signed and unsigned bounds) of this range and other.
This lattice value represents the integer range of an SSA value.
const ConstantIntRanges & getValue() const
Get the known integer value range.
bool isUninitialized() const
Whether the range is uninitialized.
ConstantIntRanges inferAnd(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShl(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferIndexOp(const InferRangeFn &inferFn, ArrayRef< ConstantIntRanges > argRanges, CmpMode mode)
Compute inferFn on ranges, whose size should be the index storage bitwidth.
ConstantIntRanges inferShrS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extSIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the signed values in range to sign-extend it to destWidth.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
ConstantIntRanges inferMinS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMaxU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferAffineExpr(AffineExpr expr, ArrayRef< ConstantIntRanges > dimRanges, ArrayRef< ConstantIntRanges > symbolRanges)
Infer the integer range for an affine expression given ranges for its dimensions and symbols.
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
ConstantIntRanges inferOr(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth)
Independently zero-extend the unsigned values and sign-extend the signed values in range to destWidth...
ConstantIntRanges inferSub(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
std::function< ConstantIntRanges(ArrayRef< ConstantIntRanges >)> InferRangeFn
Function that performs inference on an array of ConstantIntRanges, abstracted away here to permit wri...
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges truncRange(const ConstantIntRanges &range, unsigned destWidth)
Truncate range to destWidth bits, taking care to handle cases such as the truncation of [255,...
ConstantIntRanges inferDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMinU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferXor(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShrU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemU(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMinWidth
ConstantIntRanges inferFloorDivS(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMaxWidth
ConstantIntRanges inferCeilDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extUIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the unsigned values in range to zero-extend it to destWidth.
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
ConstantIntRanges inferMaxS(ArrayRef< ConstantIntRanges > argRanges)
Include the generated interface declarations.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
Definition AffineExpr.h:50
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
Definition AffineExpr.h:46
@ DimId
Dimensional identifier.
Definition AffineExpr.h:59
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
Definition AffineExpr.h:48
@ Constant
Constant integer.
Definition AffineExpr.h:57
@ SymbolId
Symbolic identifier.
Definition AffineExpr.h:61
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152