MLIR 22.0.0git
InferIntRangeInterfaceImpls.cpp
Go to the documentation of this file.
//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for index -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12
13#include <optional>
14
15#define DEBUG_TYPE "int-range-analysis"
16
17using namespace mlir;
18using namespace mlir::index;
19using namespace mlir::intrange;
20
21//===----------------------------------------------------------------------===//
22// Constants
23//===----------------------------------------------------------------------===//
24
25void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
26 SetIntRangeFn setResultRange) {
27 const APInt &value = getValue();
28 setResultRange(getResult(), ConstantIntRanges::constant(value));
29}
30
31void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
32 SetIntRangeFn setResultRange) {
33 bool value = getValue();
34 APInt asInt(/*numBits=*/1, value);
35 setResultRange(getResult(), ConstantIntRanges::constant(asInt));
36}
37
38//===----------------------------------------------------------------------===//
// Arithmetic operations. All of these operations will have their results
// inferred using both the 64-bit values and truncated 32-bit values of their
// inputs, with the results being the union of those inferences, except where
// the truncation of the 64-bit result is equal to the 32-bit result (at which
// time we take the 64-bit result).
44//===----------------------------------------------------------------------===//
45
46// Some arithmetic inference functions allow specifying special overflow / wrap
47// behavior. We do not require this for the IndexOps and use this helper to call
48// the inference function without any `OverflowFlags`.
51 return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) {
52 return inferWithOvfFn(argRanges, OverflowFlags::None);
53 };
54}
55
56void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
57 SetIntRangeFn setResultRange) {
58 setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd),
59 argRanges, CmpMode::Both));
60}
61
62void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
63 SetIntRangeFn setResultRange) {
64 setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub),
65 argRanges, CmpMode::Both));
66}
67
68void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
69 SetIntRangeFn setResultRange) {
70 setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul),
71 argRanges, CmpMode::Both));
72}
73
74void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
75 SetIntRangeFn setResultRange) {
76 setResultRange(getResult(),
78}
79
80void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
81 SetIntRangeFn setResultRange) {
82 setResultRange(getResult(),
84}
85
86void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
87 SetIntRangeFn setResultRange) {
88 setResultRange(getResult(),
90}
91
92void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
93 SetIntRangeFn setResultRange) {
94 setResultRange(getResult(),
96}
97
98void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
99 SetIntRangeFn setResultRange) {
100 return setResultRange(
101 getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed));
102}
103
104void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
105 SetIntRangeFn setResultRange) {
106 setResultRange(getResult(),
108}
109
110void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
111 SetIntRangeFn setResultRange) {
112 setResultRange(getResult(),
114}
115
116void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
117 SetIntRangeFn setResultRange) {
118 setResultRange(getResult(),
120}
121
122void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
123 SetIntRangeFn setResultRange) {
124 setResultRange(getResult(),
126}
127
128void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
129 SetIntRangeFn setResultRange) {
130 setResultRange(getResult(),
132}
133
134void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
135 SetIntRangeFn setResultRange) {
136 setResultRange(getResult(),
138}
139
140void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
141 SetIntRangeFn setResultRange) {
142 setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl),
143 argRanges, CmpMode::Both));
144}
145
146void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
147 SetIntRangeFn setResultRange) {
148 setResultRange(getResult(),
150}
151
152void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
153 SetIntRangeFn setResultRange) {
154 setResultRange(getResult(),
156}
157
158void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
159 SetIntRangeFn setResultRange) {
160 setResultRange(getResult(),
162}
163
164void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
165 SetIntRangeFn setResultRange) {
166 setResultRange(getResult(),
168}
169
170void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
171 SetIntRangeFn setResultRange) {
172 setResultRange(getResult(),
174}
175
176//===----------------------------------------------------------------------===//
177// Casts
178//===----------------------------------------------------------------------===//
179
181 unsigned srcWidth, unsigned destWidth,
182 bool isSigned) {
183 if (srcWidth < destWidth)
184 return isSigned ? extSIRange(range, destWidth)
185 : extUIRange(range, destWidth);
186 if (srcWidth > destWidth)
187 return truncRange(range, destWidth);
188 return range;
189}
190
191// When casting to `index`, we will take the union of the possible fixed-width
192// casts.
194 Type sourceType, Type destType,
195 bool isSigned) {
196 unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
197 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
198 if (sourceType.isIndex())
199 return makeLikeDest(range, srcWidth, destWidth, isSigned);
200 // We are casting to indexs, so use the union of the 32-bit and 64-bit casts
201 ConstantIntRanges storageRange =
202 makeLikeDest(range, srcWidth, destWidth, isSigned);
203 ConstantIntRanges minWidthRange =
204 makeLikeDest(range, srcWidth, indexMinWidth, isSigned);
205 ConstantIntRanges minWidthExt = extRange(minWidthRange, destWidth);
206 ConstantIntRanges ret = storageRange.rangeUnion(minWidthExt);
207 return ret;
208}
209
210void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
211 SetIntRangeFn setResultRange) {
212 Type sourceType = getOperand().getType();
213 Type destType = getResult().getType();
214 setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
215 /*isSigned=*/true));
216}
217
218void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
219 SetIntRangeFn setResultRange) {
220 Type sourceType = getOperand().getType();
221 Type destType = getResult().getType();
222 setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
223 /*isSigned=*/false));
224}
225
226//===----------------------------------------------------------------------===//
227// CmpOp
228//===----------------------------------------------------------------------===//
229
230void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
231 SetIntRangeFn setResultRange) {
232 index::IndexCmpPredicate indexPred = getPred();
233 intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(indexPred);
234 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
235
236 APInt min = APInt::getZero(1);
237 APInt max = APInt::getAllOnes(1);
238
239 std::optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);
240
242 rhsTrunc = truncRange(rhs, indexMinWidth);
243 std::optional<bool> truthValue32 =
244 intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
245
246 if (truthValue64 == truthValue32) {
247 if (truthValue64.has_value() && *truthValue64)
248 min = max;
249 else if (truthValue64.has_value() && !(*truthValue64))
250 max = min;
251 }
252 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
253}
254
255//===----------------------------------------------------------------------===//
// SizeOf, which is bounded between the two supported bitwidths (32 and 64).
257//===----------------------------------------------------------------------===//
258
259void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
260 SetIntRangeFn setResultRange) {
261 unsigned storageWidth =
263 APInt min(/*numBits=*/storageWidth, indexMinWidth);
264 APInt max(/*numBits=*/storageWidth, indexMaxWidth);
265 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
266}
lhs
static std::function< ConstantIntRanges(ArrayRef< ConstantIntRanges >)> inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn)
static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range, unsigned srcWidth, unsigned destWidth, bool isSigned)
static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range, Type sourceType, Type destType, bool isSigned)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create a ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const
Returns the union (computed separately for signed and unsigned bounds) of this range and other.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:54
ConstantIntRanges inferAnd(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShl(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferIndexOp(const InferRangeFn &inferFn, ArrayRef< ConstantIntRanges > argRanges, CmpMode mode)
Compute inferFn on ranges, whose size should be the index storage bitwidth.
ConstantIntRanges inferShrS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extSIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the signed values in range to sign-extend it to destWidth.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
ConstantIntRanges inferMinS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMaxU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemS(ArrayRef< ConstantIntRanges > argRanges)
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
ConstantIntRanges inferOr(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth)
Independently zero-extend the unsigned values and sign-extend the signed values in range to destWidth...
ConstantIntRanges inferSub(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges truncRange(const ConstantIntRanges &range, unsigned destWidth)
Truncate range to destWidth bits, taking care to handle cases such as the truncation of [255,...
ConstantIntRanges inferDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMinU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferXor(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShrU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemU(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMinWidth
ConstantIntRanges inferFloorDivS(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMaxWidth
ConstantIntRanges inferCeilDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extUIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the unsigned values in range to zero-extend it to destWidth.
function_ref< ConstantIntRanges(ArrayRef< ConstantIntRanges >, OverflowFlags)> InferRangeWithOvfFlagsFn
Function that performs inference on an array of ConstantIntRanges while taking special overflow behav...
ConstantIntRanges inferMaxS(ArrayRef< ConstantIntRanges > argRanges)
Include the generated interface declarations.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304