MLIR 23.0.0git
InferIntRangeInterfaceImpls.cpp
Go to the documentation of this file.
1//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12
13#include <optional>
14
15#define DEBUG_TYPE "int-range-analysis"
16
17using namespace mlir;
18using namespace mlir::arith;
19using namespace mlir::intrange;
20
22convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
24 if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nsw))
26 if (bitEnumContainsAny(flags, arith::IntegerOverflowFlags::nuw))
28 return retFlags;
29}
30
31//===----------------------------------------------------------------------===//
32// ConstantOp
33//===----------------------------------------------------------------------===//
34
35void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
36 SetIntRangeFn setResultRange) {
37 if (auto scalarCstAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue())) {
38 const APInt &value = scalarCstAttr.getValue();
39 setResultRange(getResult(), ConstantIntRanges::constant(value));
40 return;
41 }
42 if (auto arrayCstAttr =
43 llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) {
44 if (arrayCstAttr.empty())
45 return;
46
47 if (arrayCstAttr.isSplat()) {
48 setResultRange(getResult(), ConstantIntRanges::constant(
49 arrayCstAttr.getSplatValue<APInt>()));
50 return;
51 }
52
53 std::optional<ConstantIntRanges> result;
54 for (const APInt &val : arrayCstAttr) {
55 auto range = ConstantIntRanges::constant(val);
56 result = (result ? result->rangeUnion(range) : range);
57 }
58
59 setResultRange(getResult(), *result);
60 return;
61 }
62}
63
64//===----------------------------------------------------------------------===//
65// AddIOp
66//===----------------------------------------------------------------------===//
67
68void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
69 SetIntRangeFn setResultRange) {
70 setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
71 getOverflowFlags())));
72}
73
74//===----------------------------------------------------------------------===//
75// SubIOp
76//===----------------------------------------------------------------------===//
77
78void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
79 SetIntRangeFn setResultRange) {
80 setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
81 getOverflowFlags())));
82}
83
84//===----------------------------------------------------------------------===//
85// MulIOp
86//===----------------------------------------------------------------------===//
87
88void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
89 SetIntRangeFn setResultRange) {
90 setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
91 getOverflowFlags())));
92}
93
94//===----------------------------------------------------------------------===//
95// DivUIOp
96//===----------------------------------------------------------------------===//
97
98void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
99 SetIntRangeFn setResultRange) {
100 setResultRange(getResult(), inferDivU(argRanges));
101}
102
103//===----------------------------------------------------------------------===//
104// DivSIOp
105//===----------------------------------------------------------------------===//
106
107void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
108 SetIntRangeFn setResultRange) {
109 setResultRange(getResult(), inferDivS(argRanges));
110}
111
112//===----------------------------------------------------------------------===//
113// CeilDivUIOp
114//===----------------------------------------------------------------------===//
115
116void arith::CeilDivUIOp::inferResultRanges(
117 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
118 setResultRange(getResult(), inferCeilDivU(argRanges));
119}
120
121//===----------------------------------------------------------------------===//
122// CeilDivSIOp
123//===----------------------------------------------------------------------===//
124
125void arith::CeilDivSIOp::inferResultRanges(
126 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
127 setResultRange(getResult(), inferCeilDivS(argRanges));
128}
129
130//===----------------------------------------------------------------------===//
131// FloorDivSIOp
132//===----------------------------------------------------------------------===//
133
134void arith::FloorDivSIOp::inferResultRanges(
135 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
136 return setResultRange(getResult(), inferFloorDivS(argRanges));
137}
138
139//===----------------------------------------------------------------------===//
140// RemUIOp
141//===----------------------------------------------------------------------===//
142
143void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
144 SetIntRangeFn setResultRange) {
145 setResultRange(getResult(), inferRemU(argRanges));
146}
147
148//===----------------------------------------------------------------------===//
149// RemSIOp
150//===----------------------------------------------------------------------===//
151
152void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
153 SetIntRangeFn setResultRange) {
154 setResultRange(getResult(), inferRemS(argRanges));
155}
156
157//===----------------------------------------------------------------------===//
158// AndIOp
159//===----------------------------------------------------------------------===//
160
161void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
162 SetIntRangeFn setResultRange) {
163 setResultRange(getResult(), inferAnd(argRanges));
164}
165
166//===----------------------------------------------------------------------===//
167// OrIOp
168//===----------------------------------------------------------------------===//
169
170void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
171 SetIntRangeFn setResultRange) {
172 setResultRange(getResult(), inferOr(argRanges));
173}
174
175//===----------------------------------------------------------------------===//
176// XOrIOp
177//===----------------------------------------------------------------------===//
178
179void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
180 SetIntRangeFn setResultRange) {
181 setResultRange(getResult(), inferXor(argRanges));
182}
183
184//===----------------------------------------------------------------------===//
185// MaxSIOp
186//===----------------------------------------------------------------------===//
187
188void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
189 SetIntRangeFn setResultRange) {
190 setResultRange(getResult(), inferMaxS(argRanges));
191}
192
193//===----------------------------------------------------------------------===//
194// MaxUIOp
195//===----------------------------------------------------------------------===//
196
197void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
198 SetIntRangeFn setResultRange) {
199 setResultRange(getResult(), inferMaxU(argRanges));
200}
201
202//===----------------------------------------------------------------------===//
203// MinSIOp
204//===----------------------------------------------------------------------===//
205
206void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
207 SetIntRangeFn setResultRange) {
208 setResultRange(getResult(), inferMinS(argRanges));
209}
210
211//===----------------------------------------------------------------------===//
212// MinUIOp
213//===----------------------------------------------------------------------===//
214
215void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
216 SetIntRangeFn setResultRange) {
217 setResultRange(getResult(), inferMinU(argRanges));
218}
219
220//===----------------------------------------------------------------------===//
221// ExtUIOp
222//===----------------------------------------------------------------------===//
223
224void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
225 SetIntRangeFn setResultRange) {
226 unsigned destWidth =
228 setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
229}
230
231//===----------------------------------------------------------------------===//
232// ExtSIOp
233//===----------------------------------------------------------------------===//
234
235void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
236 SetIntRangeFn setResultRange) {
237 unsigned destWidth =
239 setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
240}
241
242//===----------------------------------------------------------------------===//
243// TruncIOp
244//===----------------------------------------------------------------------===//
245
246void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
247 SetIntRangeFn setResultRange) {
248 unsigned destWidth =
250 setResultRange(getResult(), truncRange(argRanges[0], destWidth));
251}
252
253//===----------------------------------------------------------------------===//
254// IndexCastOp
255//===----------------------------------------------------------------------===//
256
257void arith::IndexCastOp::inferResultRanges(
258 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
259 Type sourceType = getOperand().getType();
260 Type destType = getResult().getType();
261 unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
262 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
263
264 if (srcWidth < destWidth)
265 setResultRange(getResult(), extSIRange(argRanges[0], destWidth));
266 else if (srcWidth > destWidth)
267 setResultRange(getResult(), truncRange(argRanges[0], destWidth));
268 else
269 setResultRange(getResult(), argRanges[0]);
270}
271
272//===----------------------------------------------------------------------===//
273// IndexCastUIOp
274//===----------------------------------------------------------------------===//
275
276void arith::IndexCastUIOp::inferResultRanges(
277 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
278 Type sourceType = getOperand().getType();
279 Type destType = getResult().getType();
280 unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
281 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
282
283 if (srcWidth < destWidth)
284 setResultRange(getResult(), extUIRange(argRanges[0], destWidth));
285 else if (srcWidth > destWidth)
286 setResultRange(getResult(), truncRange(argRanges[0], destWidth));
287 else
288 setResultRange(getResult(), argRanges[0]);
289}
290
291//===----------------------------------------------------------------------===//
292// CmpIOp
293//===----------------------------------------------------------------------===//
294
295void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
296 SetIntRangeFn setResultRange) {
297 arith::CmpIPredicate arithPred = getPredicate();
298 intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(arithPred);
299 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
300
301 APInt min = APInt::getZero(1);
302 APInt max = APInt::getAllOnes(1);
303
304 std::optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs);
305 if (truthValue.has_value() && *truthValue)
306 min = max;
307 else if (truthValue.has_value() && !(*truthValue))
308 max = min;
309
310 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
311}
312
313//===----------------------------------------------------------------------===//
314// SelectOp
315//===----------------------------------------------------------------------===//
316
317void arith::SelectOp::inferResultRangesFromOptional(
318 ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
319 std::optional<APInt> mbCondVal =
320 argRanges[0].isUninitialized()
321 ? std::nullopt
322 : argRanges[0].getValue().getConstantValue();
323
324 const IntegerValueRange &trueCase = argRanges[1];
325 const IntegerValueRange &falseCase = argRanges[2];
326
327 if (mbCondVal) {
328 if (mbCondVal->isZero())
329 setResultRange(getResult(), falseCase);
330 else
331 setResultRange(getResult(), trueCase);
332 return;
333 }
334
335 // When one of the ranges is uninitialized, set the whole range to max
336 // otherwise the result will ignore the uninitialized range.
337 if (trueCase.isUninitialized() || falseCase.isUninitialized())
338 setResultRange(getResult(), IntegerValueRange::getMaxRange(getResult()));
339 else
340 setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase));
341}
342
343//===----------------------------------------------------------------------===//
344// ShLIOp
345//===----------------------------------------------------------------------===//
346
347void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
348 SetIntRangeFn setResultRange) {
349 setResultRange(getResult(), inferShl(argRanges, convertArithOverflowFlags(
350 getOverflowFlags())));
351}
352
353//===----------------------------------------------------------------------===//
354// ShRUIOp
355//===----------------------------------------------------------------------===//
356
357void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
358 SetIntRangeFn setResultRange) {
359 setResultRange(getResult(), inferShrU(argRanges));
360}
361
362//===----------------------------------------------------------------------===//
363// ShRSIOp
364//===----------------------------------------------------------------------===//
365
366void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
367 SetIntRangeFn setResultRange) {
368 setResultRange(getResult(), inferShrS(argRanges));
369}
static intrange::OverflowFlags convertArithOverflowFlags(arith::IntegerOverflowFlags flags)
lhs
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create a ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
bool isUninitialized() const
Whether the range is uninitialized.
static IntegerValueRange getMaxRange(Value value)
Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) range that is used to mark the v...
static IntegerValueRange join(const IntegerValueRange &lhs, const IntegerValueRange &rhs)
Compute the least upper bound of two ranges.
ConstantIntRanges inferAnd(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShl(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferShrS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extSIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the signed values in range to sign-extend it to destWidth.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
ConstantIntRanges inferMinS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMaxU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemS(ArrayRef< ConstantIntRanges > argRanges)
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
ConstantIntRanges inferOr(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferSub(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges truncRange(const ConstantIntRanges &range, unsigned destWidth)
Truncate range to destWidth bits, taking care to handle cases such as the truncation of [255,...
ConstantIntRanges inferDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMinU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferXor(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShrU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferFloorDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extUIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the unsigned values in range to zero-extend it to destWidth.
ConstantIntRanges inferMaxS(ArrayRef< ConstantIntRanges > argRanges)
Include the generated interface declarations.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307