MLIR 22.0.0git
InferIntRangeInterfaceImpls.cpp
Go to the documentation of this file.
1//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12
13#include <optional>
14
15#define DEBUG_TYPE "int-range-analysis"
16
17using namespace mlir;
18using namespace mlir::arith;
19using namespace mlir::intrange;
20
// Translate the arith dialect's integer overflow flags (`nsw` / `nuw`) into
// the intrange::OverflowFlags value taken by the shared range-inference
// helpers (inferAdd, inferSub, inferMul and inferShl are called with the
// result of this function below).
// NOTE(review): this listing appears to have lost several source lines during
// extraction: the return-type line preceding the function name, the
// declaration/initialization of `retFlags`, and the statement guarded by each
// `if`. Presumably each `if` ORs the matching intrange flag into `retFlags`;
// confirm against the upstream file before relying on this text.
22convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
// Is the `nsw` bit set on the arith op's flags?
24 if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nsw))
// Is the `nuw` bit set on the arith op's flags?
26 if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nuw))
28 return retFlags;
29}
30
31//===----------------------------------------------------------------------===//
32// ConstantOp
33//===----------------------------------------------------------------------===//
34
35void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
36 SetIntRangeFn setResultRange) {
37 if (auto scalarCstAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue())) {
38 const APInt &value = scalarCstAttr.getValue();
39 setResultRange(getResult(), ConstantIntRanges::constant(value));
40 return;
41 }
42 if (auto arrayCstAttr =
43 llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) {
44 if (arrayCstAttr.isSplat()) {
45 setResultRange(getResult(), ConstantIntRanges::constant(
46 arrayCstAttr.getSplatValue<APInt>()));
47 return;
48 }
49
50 std::optional<ConstantIntRanges> result;
51 for (const APInt &val : arrayCstAttr) {
52 auto range = ConstantIntRanges::constant(val);
53 result = (result ? result->rangeUnion(range) : range);
54 }
55
56 assert(result && "Zero-sized vectors are not allowed");
57 setResultRange(getResult(), *result);
58 return;
59 }
60}
61
62//===----------------------------------------------------------------------===//
63// AddIOp
64//===----------------------------------------------------------------------===//
65
66void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
67 SetIntRangeFn setResultRange) {
68 setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
69 getOverflowFlags())));
70}
71
72//===----------------------------------------------------------------------===//
73// SubIOp
74//===----------------------------------------------------------------------===//
75
76void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
77 SetIntRangeFn setResultRange) {
78 setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
79 getOverflowFlags())));
80}
81
82//===----------------------------------------------------------------------===//
83// MulIOp
84//===----------------------------------------------------------------------===//
85
86void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
87 SetIntRangeFn setResultRange) {
88 setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
89 getOverflowFlags())));
90}
91
92//===----------------------------------------------------------------------===//
93// DivUIOp
94//===----------------------------------------------------------------------===//
95
96void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
97 SetIntRangeFn setResultRange) {
98 setResultRange(getResult(), inferDivU(argRanges));
99}
100
101//===----------------------------------------------------------------------===//
102// DivSIOp
103//===----------------------------------------------------------------------===//
104
105void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
106 SetIntRangeFn setResultRange) {
107 setResultRange(getResult(), inferDivS(argRanges));
108}
109
110//===----------------------------------------------------------------------===//
111// CeilDivUIOp
112//===----------------------------------------------------------------------===//
113
114void arith::CeilDivUIOp::inferResultRanges(
115 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
116 setResultRange(getResult(), inferCeilDivU(argRanges));
117}
118
119//===----------------------------------------------------------------------===//
120// CeilDivSIOp
121//===----------------------------------------------------------------------===//
122
123void arith::CeilDivSIOp::inferResultRanges(
124 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
125 setResultRange(getResult(), inferCeilDivS(argRanges));
126}
127
128//===----------------------------------------------------------------------===//
129// FloorDivSIOp
130//===----------------------------------------------------------------------===//
131
132void arith::FloorDivSIOp::inferResultRanges(
133 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
134 return setResultRange(getResult(), inferFloorDivS(argRanges));
135}
136
137//===----------------------------------------------------------------------===//
138// RemUIOp
139//===----------------------------------------------------------------------===//
140
141void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
142 SetIntRangeFn setResultRange) {
143 setResultRange(getResult(), inferRemU(argRanges));
144}
145
146//===----------------------------------------------------------------------===//
147// RemSIOp
148//===----------------------------------------------------------------------===//
149
150void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
151 SetIntRangeFn setResultRange) {
152 setResultRange(getResult(), inferRemS(argRanges));
153}
154
155//===----------------------------------------------------------------------===//
156// AndIOp
157//===----------------------------------------------------------------------===//
158
159void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
160 SetIntRangeFn setResultRange) {
161 setResultRange(getResult(), inferAnd(argRanges));
162}
163
164//===----------------------------------------------------------------------===//
165// OrIOp
166//===----------------------------------------------------------------------===//
167
168void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
169 SetIntRangeFn setResultRange) {
170 setResultRange(getResult(), inferOr(argRanges));
171}
172
173//===----------------------------------------------------------------------===//
174// XOrIOp
175//===----------------------------------------------------------------------===//
176
177void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
178 SetIntRangeFn setResultRange) {
179 setResultRange(getResult(), inferXor(argRanges));
180}
181
182//===----------------------------------------------------------------------===//
183// MaxSIOp
184//===----------------------------------------------------------------------===//
185
186void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
187 SetIntRangeFn setResultRange) {
188 setResultRange(getResult(), inferMaxS(argRanges));
189}
190
191//===----------------------------------------------------------------------===//
192// MaxUIOp
193//===----------------------------------------------------------------------===//
194
195void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
196 SetIntRangeFn setResultRange) {
197 setResultRange(getResult(), inferMaxU(argRanges));
198}
199
200//===----------------------------------------------------------------------===//
201// MinSIOp
202//===----------------------------------------------------------------------===//
203
204void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
205 SetIntRangeFn setResultRange) {
206 setResultRange(getResult(), inferMinS(argRanges));
207}
208
209//===----------------------------------------------------------------------===//
210// MinUIOp
211//===----------------------------------------------------------------------===//
212
213void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
214 SetIntRangeFn setResultRange) {
215 setResultRange(getResult(), inferMinU(argRanges));
216}
217
218//===----------------------------------------------------------------------===//
219// ExtUIOp
220//===----------------------------------------------------------------------===//
221
222void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
223 SetIntRangeFn setResultRange) {
224 unsigned destWidth =
226 setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
227}
228
229//===----------------------------------------------------------------------===//
230// ExtSIOp
231//===----------------------------------------------------------------------===//
232
233void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
234 SetIntRangeFn setResultRange) {
235 unsigned destWidth =
237 setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
238}
239
240//===----------------------------------------------------------------------===//
241// TruncIOp
242//===----------------------------------------------------------------------===//
243
244void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
245 SetIntRangeFn setResultRange) {
246 unsigned destWidth =
248 setResultRange(getResult(), truncRange(argRanges[0], destWidth));
249}
250
251//===----------------------------------------------------------------------===//
252// IndexCastOp
253//===----------------------------------------------------------------------===//
254
255void arith::IndexCastOp::inferResultRanges(
256 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
257 Type sourceType = getOperand().getType();
258 Type destType = getResult().getType();
259 unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
260 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
261
262 if (srcWidth < destWidth)
263 setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
264 else if (srcWidth > destWidth)
265 setResultRange(getResult(), truncRange(argRanges[0], destWidth));
266 else
267 setResultRange(getResult(), argRanges[0]);
268}
269
270//===----------------------------------------------------------------------===//
271// IndexCastUIOp
272//===----------------------------------------------------------------------===//
273
274void arith::IndexCastUIOp::inferResultRanges(
275 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
276 Type sourceType = getOperand().getType();
277 Type destType = getResult().getType();
278 unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
279 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
280
281 if (srcWidth < destWidth)
282 setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
283 else if (srcWidth > destWidth)
284 setResultRange(getResult(), truncRange(argRanges[0], destWidth));
285 else
286 setResultRange(getResult(), argRanges[0]);
287}
288
289//===----------------------------------------------------------------------===//
290// CmpIOp
291//===----------------------------------------------------------------------===//
292
293void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
294 SetIntRangeFn setResultRange) {
295 arith::CmpIPredicate arithPred = getPredicate();
296 intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred);
297 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
298
299 APInt min = APInt::getZero(1);
300 APInt max = APInt::getAllOnes(1);
301
302 std::optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs);
303 if (truthValue.has_value() && *truthValue)
304 min = max;
305 else if (truthValue.has_value() && !(*truthValue))
306 max = min;
307
308 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
309}
310
311//===----------------------------------------------------------------------===//
312// SelectOp
313//===----------------------------------------------------------------------===//
314
315void arith::SelectOp::inferResultRangesFromOptional(
316 ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
317 std::optional<APInt> mbCondVal =
318 argRanges[0].isUninitialized()
319 ? std::nullopt
320 : argRanges[0].getValue().getConstantValue();
321
322 const IntegerValueRange &trueCase = argRanges[1];
323 const IntegerValueRange &falseCase = argRanges[2];
324
325 if (mbCondVal) {
326 if (mbCondVal->isZero())
327 setResultRange(getResult(), falseCase);
328 else
329 setResultRange(getResult(), trueCase);
330 return;
331 }
332 setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase));
333}
334
335//===----------------------------------------------------------------------===//
336// ShLIOp
337//===----------------------------------------------------------------------===//
338
339void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
340 SetIntRangeFn setResultRange) {
341 setResultRange(getResult(), inferShl(argRanges, convertArithOverflowFlags(
342 getOverflowFlags())));
343}
344
345//===----------------------------------------------------------------------===//
346// ShRUIOp
347//===----------------------------------------------------------------------===//
348
349void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
350 SetIntRangeFn setResultRange) {
351 setResultRange(getResult(), inferShrU(argRanges));
352}
353
354//===----------------------------------------------------------------------===//
355// ShRSIOp
356//===----------------------------------------------------------------------===//
357
358void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
359 SetIntRangeFn setResultRange) {
360 setResultRange(getResult(), inferShrS(argRanges));
361}
static intrange::OverflowFlags convertArithOverflowFlags(arith::IntegerOverflowFlags flags)
lhs
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create a ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
static IntegerValueRange join(const IntegerValueRange &lhs, const IntegerValueRange &rhs)
Compute the least upper bound of two ranges.
ConstantIntRanges inferAnd(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShl(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferShrS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extSIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the signed values in range to sign-extend it to destWidth.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
ConstantIntRanges inferMinS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMaxU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemS(ArrayRef< ConstantIntRanges > argRanges)
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
ConstantIntRanges inferOr(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferSub(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges truncRange(const ConstantIntRanges &range, unsigned destWidth)
Truncate range to destWidth bits, taking care to handle cases such as the truncation of [255,...
ConstantIntRanges inferDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMinU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferXor(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShrU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferFloorDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extUIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the unsigned values in range to zero-extend it to destWidth.
ConstantIntRanges inferMaxS(ArrayRef< ConstantIntRanges > argRanges)
Include the generated interface declarations.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304