MLIR  22.0.0git
InferIntRangeInterfaceImpls.cpp
Go to the documentation of this file.
1 //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 
13 #include <optional>
14 
15 #define DEBUG_TYPE "int-range-analysis"
16 
17 using namespace mlir;
18 using namespace mlir::index;
19 using namespace mlir::intrange;
20 
21 //===----------------------------------------------------------------------===//
22 // Constants
23 //===----------------------------------------------------------------------===//
24 
25 void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
26  SetIntRangeFn setResultRange) {
27  const APInt &value = getValue();
28  setResultRange(getResult(), ConstantIntRanges::constant(value));
29 }
30 
31 void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
32  SetIntRangeFn setResultRange) {
33  bool value = getValue();
34  APInt asInt(/*numBits=*/1, value);
35  setResultRange(getResult(), ConstantIntRanges::constant(asInt));
36 }
37 
38 //===----------------------------------------------------------------------===//
// Arithmetic operations. All of these operations have their results inferred
// from both the 64-bit values and the truncated 32-bit values of their inputs.
// The result is the union of those two inferences, except where the truncation
// of the 64-bit result equals the 32-bit result (in which case we take the
// 64-bit result).
45 
46 // Some arithmetic inference functions allow specifying special overflow / wrap
47 // behavior. We do not require this for the IndexOps and use this helper to call
48 // the inference function without any `OverflowFlags`.
51  return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) {
52  return inferWithOvfFn(argRanges, OverflowFlags::None);
53  };
54 }
55 
56 void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
57  SetIntRangeFn setResultRange) {
58  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd),
59  argRanges, CmpMode::Both));
60 }
61 
62 void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
63  SetIntRangeFn setResultRange) {
64  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub),
65  argRanges, CmpMode::Both));
66 }
67 
68 void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
69  SetIntRangeFn setResultRange) {
70  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul),
71  argRanges, CmpMode::Both));
72 }
73 
74 void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
75  SetIntRangeFn setResultRange) {
76  setResultRange(getResult(),
77  inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned));
78 }
79 
80 void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
81  SetIntRangeFn setResultRange) {
82  setResultRange(getResult(),
84 }
85 
86 void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
87  SetIntRangeFn setResultRange) {
88  setResultRange(getResult(),
89  inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned));
90 }
91 
92 void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
93  SetIntRangeFn setResultRange) {
94  setResultRange(getResult(),
96 }
97 
98 void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
99  SetIntRangeFn setResultRange) {
100  return setResultRange(
101  getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed));
102 }
103 
104 void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
105  SetIntRangeFn setResultRange) {
106  setResultRange(getResult(),
107  inferIndexOp(inferRemS, argRanges, CmpMode::Signed));
108 }
109 
110 void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
111  SetIntRangeFn setResultRange) {
112  setResultRange(getResult(),
113  inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned));
114 }
115 
116 void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
117  SetIntRangeFn setResultRange) {
118  setResultRange(getResult(),
119  inferIndexOp(inferMaxS, argRanges, CmpMode::Signed));
120 }
121 
122 void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
123  SetIntRangeFn setResultRange) {
124  setResultRange(getResult(),
125  inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned));
126 }
127 
128 void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
129  SetIntRangeFn setResultRange) {
130  setResultRange(getResult(),
131  inferIndexOp(inferMinS, argRanges, CmpMode::Signed));
132 }
133 
134 void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
135  SetIntRangeFn setResultRange) {
136  setResultRange(getResult(),
137  inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned));
138 }
139 
140 void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
141  SetIntRangeFn setResultRange) {
142  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl),
143  argRanges, CmpMode::Both));
144 }
145 
146 void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
147  SetIntRangeFn setResultRange) {
148  setResultRange(getResult(),
149  inferIndexOp(inferShrS, argRanges, CmpMode::Signed));
150 }
151 
152 void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
153  SetIntRangeFn setResultRange) {
154  setResultRange(getResult(),
155  inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned));
156 }
157 
158 void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
159  SetIntRangeFn setResultRange) {
160  setResultRange(getResult(),
161  inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned));
162 }
163 
164 void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
165  SetIntRangeFn setResultRange) {
166  setResultRange(getResult(),
167  inferIndexOp(inferOr, argRanges, CmpMode::Unsigned));
168 }
169 
170 void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
171  SetIntRangeFn setResultRange) {
172  setResultRange(getResult(),
173  inferIndexOp(inferXor, argRanges, CmpMode::Unsigned));
174 }
175 
176 //===----------------------------------------------------------------------===//
177 // Casts
178 //===----------------------------------------------------------------------===//
179 
181  unsigned srcWidth, unsigned destWidth,
182  bool isSigned) {
183  if (srcWidth < destWidth)
184  return isSigned ? extSIRange(range, destWidth)
185  : extUIRange(range, destWidth);
186  if (srcWidth > destWidth)
187  return truncRange(range, destWidth);
188  return range;
189 }
190 
191 // When casting to `index`, we will take the union of the possible fixed-width
192 // casts.
194  Type sourceType, Type destType,
195  bool isSigned) {
196  unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
197  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
198  if (sourceType.isIndex())
199  return makeLikeDest(range, srcWidth, destWidth, isSigned);
200  // We are casting to indexs, so use the union of the 32-bit and 64-bit casts
201  ConstantIntRanges storageRange =
202  makeLikeDest(range, srcWidth, destWidth, isSigned);
203  ConstantIntRanges minWidthRange =
204  makeLikeDest(range, srcWidth, indexMinWidth, isSigned);
205  ConstantIntRanges minWidthExt = extRange(minWidthRange, destWidth);
206  ConstantIntRanges ret = storageRange.rangeUnion(minWidthExt);
207  return ret;
208 }
209 
210 void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
211  SetIntRangeFn setResultRange) {
212  Type sourceType = getOperand().getType();
213  Type destType = getResult().getType();
214  setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
215  /*isSigned=*/true));
216 }
217 
218 void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
219  SetIntRangeFn setResultRange) {
220  Type sourceType = getOperand().getType();
221  Type destType = getResult().getType();
222  setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
223  /*isSigned=*/false));
224 }
225 
226 //===----------------------------------------------------------------------===//
227 // CmpOp
228 //===----------------------------------------------------------------------===//
229 
230 void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
231  SetIntRangeFn setResultRange) {
232  index::IndexCmpPredicate indexPred = getPred();
233  intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(indexPred);
234  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
235 
236  APInt min = APInt::getZero(1);
237  APInt max = APInt::getAllOnes(1);
238 
239  std::optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);
240 
241  ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
242  rhsTrunc = truncRange(rhs, indexMinWidth);
243  std::optional<bool> truthValue32 =
244  intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
245 
246  if (truthValue64 == truthValue32) {
247  if (truthValue64.has_value() && *truthValue64)
248  min = max;
249  else if (truthValue64.has_value() && !(*truthValue64))
250  max = min;
251  }
252  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
253 }
254 
255 //===----------------------------------------------------------------------===//
// SizeOf, which is bounded between the two supported bitwidths (32 and 64).
257 //===----------------------------------------------------------------------===//
258 
259 void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
260  SetIntRangeFn setResultRange) {
261  unsigned storageWidth =
263  APInt min(/*numBits=*/storageWidth, indexMinWidth);
264  APInt max(/*numBits=*/storageWidth, indexMaxWidth);
265  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
266 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static std::function< ConstantIntRanges(ArrayRef< ConstantIntRanges >)> inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn)
static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range, unsigned srcWidth, unsigned destWidth, bool isSigned)
static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range, Type sourceType, Type destType, bool isSigned)
@ None
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create an ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const
Returns the union (computed separately for signed and unsigned bounds) of this range and other.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
ConstantIntRanges inferAnd(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShl(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferIndexOp(const InferRangeFn &inferFn, ArrayRef< ConstantIntRanges > argRanges, CmpMode mode)
Compute inferFn on ranges, whose size should be the index storage bitwidth.
ConstantIntRanges inferShrS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extSIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the signed values in range to sign-extend it to destWidth.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
ConstantIntRanges inferMinS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMaxU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemS(ArrayRef< ConstantIntRanges > argRanges)
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
ConstantIntRanges inferOr(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth)
Independently zero-extend the unsigned values and sign-extend the signed values in range to destWidth...
ConstantIntRanges inferSub(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges truncRange(const ConstantIntRanges &range, unsigned destWidth)
Truncate range to destWidth bits, taking care to handle cases such as the truncation of [255,...
ConstantIntRanges inferDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMinU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferXor(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShrU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemU(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMinWidth
ConstantIntRanges inferFloorDivS(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMaxWidth
ConstantIntRanges inferCeilDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extUIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the unsigned values in range to zero-extend it to destWidth.
ConstantIntRanges inferMaxS(ArrayRef< ConstantIntRanges > argRanges)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304