MLIR  20.0.0git
InferIntRangeInterfaceImpls.cpp
Go to the documentation of this file.
1 //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for index -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 
13 #include "llvm/Support/Debug.h"
14 #include <optional>
15 
16 #define DEBUG_TYPE "int-range-analysis"
17 
18 using namespace mlir;
19 using namespace mlir::index;
20 using namespace mlir::intrange;
21 
22 //===----------------------------------------------------------------------===//
23 // Constants
24 //===----------------------------------------------------------------------===//
25 
26 void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
27  SetIntRangeFn setResultRange) {
28  const APInt &value = getValue();
29  setResultRange(getResult(), ConstantIntRanges::constant(value));
30 }
31 
32 void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
33  SetIntRangeFn setResultRange) {
34  bool value = getValue();
35  APInt asInt(/*numBits=*/1, value);
36  setResultRange(getResult(), ConstantIntRanges::constant(asInt));
37 }
38 
39 //===----------------------------------------------------------------------===//
40 // Arithmetic operations. All of these operations will have their results inferred
41 // using both the 64-bit values and truncated 32-bit values of their inputs,
42 // with the results being the union of those inferences, except where the
43 // truncation of the 64-bit result is equal to the 32-bit result (at which time
44 // we take the 64-bit result).
45 //===----------------------------------------------------------------------===//
46 
47 // Some arithmetic inference functions allow specifying special overflow / wrap
48 // behavior. We do not require this for the IndexOps and use this helper to call
49 // the inference function without any `OverflowFlags`.
52  return [inferWithOvfFn](ArrayRef<ConstantIntRanges> argRanges) {
53  return inferWithOvfFn(argRanges, OverflowFlags::None);
54  };
55 }
56 
57 void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
58  SetIntRangeFn setResultRange) {
59  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferAdd),
60  argRanges, CmpMode::Both));
61 }
62 
63 void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
64  SetIntRangeFn setResultRange) {
65  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferSub),
66  argRanges, CmpMode::Both));
67 }
68 
69 void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
70  SetIntRangeFn setResultRange) {
71  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferMul),
72  argRanges, CmpMode::Both));
73 }
74 
75 void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
76  SetIntRangeFn setResultRange) {
77  setResultRange(getResult(),
78  inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned));
79 }
80 
81 void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
82  SetIntRangeFn setResultRange) {
83  setResultRange(getResult(),
85 }
86 
87 void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
88  SetIntRangeFn setResultRange) {
89  setResultRange(getResult(),
90  inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned));
91 }
92 
93 void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
94  SetIntRangeFn setResultRange) {
95  setResultRange(getResult(),
97 }
98 
99 void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
100  SetIntRangeFn setResultRange) {
101  return setResultRange(
102  getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed));
103 }
104 
105 void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
106  SetIntRangeFn setResultRange) {
107  setResultRange(getResult(),
108  inferIndexOp(inferRemS, argRanges, CmpMode::Signed));
109 }
110 
111 void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
112  SetIntRangeFn setResultRange) {
113  setResultRange(getResult(),
114  inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned));
115 }
116 
117 void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
118  SetIntRangeFn setResultRange) {
119  setResultRange(getResult(),
120  inferIndexOp(inferMaxS, argRanges, CmpMode::Signed));
121 }
122 
123 void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
124  SetIntRangeFn setResultRange) {
125  setResultRange(getResult(),
126  inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned));
127 }
128 
129 void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
130  SetIntRangeFn setResultRange) {
131  setResultRange(getResult(),
132  inferIndexOp(inferMinS, argRanges, CmpMode::Signed));
133 }
134 
135 void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
136  SetIntRangeFn setResultRange) {
137  setResultRange(getResult(),
138  inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned));
139 }
140 
141 void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
142  SetIntRangeFn setResultRange) {
143  setResultRange(getResult(), inferIndexOp(inferWithoutOverflowFlags(inferShl),
144  argRanges, CmpMode::Both));
145 }
146 
147 void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
148  SetIntRangeFn setResultRange) {
149  setResultRange(getResult(),
150  inferIndexOp(inferShrS, argRanges, CmpMode::Signed));
151 }
152 
153 void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
154  SetIntRangeFn setResultRange) {
155  setResultRange(getResult(),
156  inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned));
157 }
158 
159 void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
160  SetIntRangeFn setResultRange) {
161  setResultRange(getResult(),
162  inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned));
163 }
164 
165 void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
166  SetIntRangeFn setResultRange) {
167  setResultRange(getResult(),
168  inferIndexOp(inferOr, argRanges, CmpMode::Unsigned));
169 }
170 
171 void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
172  SetIntRangeFn setResultRange) {
173  setResultRange(getResult(),
174  inferIndexOp(inferXor, argRanges, CmpMode::Unsigned));
175 }
176 
177 //===----------------------------------------------------------------------===//
178 // Casts
179 //===----------------------------------------------------------------------===//
180 
182  unsigned srcWidth, unsigned destWidth,
183  bool isSigned) {
184  if (srcWidth < destWidth)
185  return isSigned ? extSIRange(range, destWidth)
186  : extUIRange(range, destWidth);
187  if (srcWidth > destWidth)
188  return truncRange(range, destWidth);
189  return range;
190 }
191 
192 // When casting to `index`, we will take the union of the possible fixed-width
193 // casts.
195  Type sourceType, Type destType,
196  bool isSigned) {
197  unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
198  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
199  if (sourceType.isIndex())
200  return makeLikeDest(range, srcWidth, destWidth, isSigned);
201  // We are casting to indexs, so use the union of the 32-bit and 64-bit casts
202  ConstantIntRanges storageRange =
203  makeLikeDest(range, srcWidth, destWidth, isSigned);
204  ConstantIntRanges minWidthRange =
205  makeLikeDest(range, srcWidth, indexMinWidth, isSigned);
206  ConstantIntRanges minWidthExt = extRange(minWidthRange, destWidth);
207  ConstantIntRanges ret = storageRange.rangeUnion(minWidthExt);
208  return ret;
209 }
210 
211 void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
212  SetIntRangeFn setResultRange) {
213  Type sourceType = getOperand().getType();
214  Type destType = getResult().getType();
215  setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
216  /*isSigned=*/true));
217 }
218 
219 void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
220  SetIntRangeFn setResultRange) {
221  Type sourceType = getOperand().getType();
222  Type destType = getResult().getType();
223  setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
224  /*isSigned=*/false));
225 }
226 
227 //===----------------------------------------------------------------------===//
228 // CmpOp
229 //===----------------------------------------------------------------------===//
230 
231 void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
232  SetIntRangeFn setResultRange) {
233  index::IndexCmpPredicate indexPred = getPred();
234  intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(indexPred);
235  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
236 
237  APInt min = APInt::getZero(1);
238  APInt max = APInt::getAllOnes(1);
239 
240  std::optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);
241 
242  ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
243  rhsTrunc = truncRange(rhs, indexMinWidth);
244  std::optional<bool> truthValue32 =
245  intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
246 
247  if (truthValue64 == truthValue32) {
248  if (truthValue64.has_value() && *truthValue64)
249  min = max;
250  else if (truthValue64.has_value() && !(*truthValue64))
251  max = min;
252  }
253  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
254 }
255 
256 //===----------------------------------------------------------------------===//
257 // SizeOf, which is bounded between the two supported bitwidths (32 and 64).
258 //===----------------------------------------------------------------------===//
259 
260 void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
261  SetIntRangeFn setResultRange) {
262  unsigned storageWidth =
264  APInt min(/*numBits=*/storageWidth, indexMinWidth);
265  APInt max(/*numBits=*/storageWidth, indexMaxWidth);
266  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
267 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static std::function< ConstantIntRanges(ArrayRef< ConstantIntRanges >)> inferWithoutOverflowFlags(InferRangeWithOvfFlagsFn inferWithOvfFn)
static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range, unsigned srcWidth, unsigned destWidth, bool isSigned)
static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range, Type sourceType, Type destType, bool isSigned)
@ None
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create a ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const
Returns the union (computed separately for signed and unsigned bounds) of this range and other.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:64
ConstantIntRanges inferAnd(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShl(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferIndexOp(const InferRangeFn &inferFn, ArrayRef< ConstantIntRanges > argRanges, CmpMode mode)
Compute inferFn on ranges, whose size should be the index storage bitwidth.
ConstantIntRanges inferShrS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extSIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the signed values in range to sign-extend it to destWidth.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
ConstantIntRanges inferMinS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMaxU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemS(ArrayRef< ConstantIntRanges > argRanges)
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
ConstantIntRanges inferOr(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth)
Independently zero-extend the unsigned values and sign-extend the signed values in range to destWidth...
ConstantIntRanges inferSub(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges truncRange(const ConstantIntRanges &range, unsigned destWidth)
Truncate range to destWidth bits, taking care to handle cases such as the truncation of [255,...
ConstantIntRanges inferDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMinU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferXor(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShrU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemU(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMinWidth
ConstantIntRanges inferFloorDivS(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMaxWidth
ConstantIntRanges inferCeilDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extUIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the unsigned values in range to zero-extend it to destWidth.
ConstantIntRanges inferMaxS(ArrayRef< ConstantIntRanges > argRanges)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305