MLIR  19.0.0git
InferIntRangeInterfaceImpls.cpp
Go to the documentation of this file.
1 //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 
13 #include "llvm/Support/Debug.h"
14 #include <optional>
15 
16 #define DEBUG_TYPE "int-range-analysis"
17 
18 using namespace mlir;
19 using namespace mlir::index;
20 using namespace mlir::intrange;
21 
22 //===----------------------------------------------------------------------===//
23 // Constants
24 //===----------------------------------------------------------------------===//
25 
26 void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
27  SetIntRangeFn setResultRange) {
28  const APInt &value = getValue();
29  setResultRange(getResult(), ConstantIntRanges::constant(value));
30 }
31 
32 void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
33  SetIntRangeFn setResultRange) {
34  bool value = getValue();
35  APInt asInt(/*numBits=*/1, value);
36  setResultRange(getResult(), ConstantIntRanges::constant(asInt));
37 }
38 
39 //===----------------------------------------------------------------------===//
40 // Arithmetic operations. All of these operations will have their results inferred
41 // using both the 64-bit values and truncated 32-bit values of their inputs,
42 // with the results being the union of those inferences, except where the
43 // truncation of the 64-bit result is equal to the 32-bit result (at which time
44 // we take the 64-bit result).
45 //===----------------------------------------------------------------------===//
46 
47 void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
48  SetIntRangeFn setResultRange) {
49  setResultRange(getResult(), inferIndexOp(inferAdd, argRanges, CmpMode::Both));
50 }
51 
52 void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
53  SetIntRangeFn setResultRange) {
54  setResultRange(getResult(), inferIndexOp(inferSub, argRanges, CmpMode::Both));
55 }
56 
57 void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
58  SetIntRangeFn setResultRange) {
59  setResultRange(getResult(), inferIndexOp(inferMul, argRanges, CmpMode::Both));
60 }
61 
62 void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
63  SetIntRangeFn setResultRange) {
64  setResultRange(getResult(),
65  inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned));
66 }
67 
68 void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
69  SetIntRangeFn setResultRange) {
70  setResultRange(getResult(),
72 }
73 
74 void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
75  SetIntRangeFn setResultRange) {
76  setResultRange(getResult(),
77  inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned));
78 }
79 
80 void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
81  SetIntRangeFn setResultRange) {
82  setResultRange(getResult(),
84 }
85 
86 void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
87  SetIntRangeFn setResultRange) {
88  return setResultRange(
89  getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed));
90 }
91 
92 void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
93  SetIntRangeFn setResultRange) {
94  setResultRange(getResult(),
96 }
97 
98 void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
99  SetIntRangeFn setResultRange) {
100  setResultRange(getResult(),
101  inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned));
102 }
103 
104 void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
105  SetIntRangeFn setResultRange) {
106  setResultRange(getResult(),
107  inferIndexOp(inferMaxS, argRanges, CmpMode::Signed));
108 }
109 
110 void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
111  SetIntRangeFn setResultRange) {
112  setResultRange(getResult(),
113  inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned));
114 }
115 
116 void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
117  SetIntRangeFn setResultRange) {
118  setResultRange(getResult(),
119  inferIndexOp(inferMinS, argRanges, CmpMode::Signed));
120 }
121 
122 void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
123  SetIntRangeFn setResultRange) {
124  setResultRange(getResult(),
125  inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned));
126 }
127 
128 void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
129  SetIntRangeFn setResultRange) {
130  setResultRange(getResult(), inferIndexOp(inferShl, argRanges, CmpMode::Both));
131 }
132 
133 void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
134  SetIntRangeFn setResultRange) {
135  setResultRange(getResult(),
136  inferIndexOp(inferShrS, argRanges, CmpMode::Signed));
137 }
138 
139 void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
140  SetIntRangeFn setResultRange) {
141  setResultRange(getResult(),
142  inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned));
143 }
144 
145 void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
146  SetIntRangeFn setResultRange) {
147  setResultRange(getResult(),
148  inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned));
149 }
150 
151 void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
152  SetIntRangeFn setResultRange) {
153  setResultRange(getResult(),
154  inferIndexOp(inferOr, argRanges, CmpMode::Unsigned));
155 }
156 
157 void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
158  SetIntRangeFn setResultRange) {
159  setResultRange(getResult(),
160  inferIndexOp(inferXor, argRanges, CmpMode::Unsigned));
161 }
162 
163 //===----------------------------------------------------------------------===//
164 // Casts
165 //===----------------------------------------------------------------------===//
166 
168  unsigned srcWidth, unsigned destWidth,
169  bool isSigned) {
170  if (srcWidth < destWidth)
171  return isSigned ? extSIRange(range, destWidth)
172  : extUIRange(range, destWidth);
173  if (srcWidth > destWidth)
174  return truncRange(range, destWidth);
175  return range;
176 }
177 
178 // When casting to `index`, we will take the union of the possible fixed-width
179 // casts.
181  Type sourceType, Type destType,
182  bool isSigned) {
183  unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
184  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
185  if (sourceType.isIndex())
186  return makeLikeDest(range, srcWidth, destWidth, isSigned);
187  // We are casting to indexs, so use the union of the 32-bit and 64-bit casts
188  ConstantIntRanges storageRange =
189  makeLikeDest(range, srcWidth, destWidth, isSigned);
190  ConstantIntRanges minWidthRange =
191  makeLikeDest(range, srcWidth, indexMinWidth, isSigned);
192  ConstantIntRanges minWidthExt = extRange(minWidthRange, destWidth);
193  ConstantIntRanges ret = storageRange.rangeUnion(minWidthExt);
194  return ret;
195 }
196 
197 void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
198  SetIntRangeFn setResultRange) {
199  Type sourceType = getOperand().getType();
200  Type destType = getResult().getType();
201  setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
202  /*isSigned=*/true));
203 }
204 
205 void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
206  SetIntRangeFn setResultRange) {
207  Type sourceType = getOperand().getType();
208  Type destType = getResult().getType();
209  setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType,
210  /*isSigned=*/false));
211 }
212 
213 //===----------------------------------------------------------------------===//
214 // CmpOp
215 //===----------------------------------------------------------------------===//
216 
217 void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
218  SetIntRangeFn setResultRange) {
219  index::IndexCmpPredicate indexPred = getPred();
220  intrange::CmpPredicate pred = static_cast<intrange::CmpPredicate>(indexPred);
221  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
222 
223  APInt min = APInt::getZero(1);
224  APInt max = APInt::getAllOnes(1);
225 
226  std::optional<bool> truthValue64 = intrange::evaluatePred(pred, lhs, rhs);
227 
228  ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth),
229  rhsTrunc = truncRange(rhs, indexMinWidth);
230  std::optional<bool> truthValue32 =
231  intrange::evaluatePred(pred, lhsTrunc, rhsTrunc);
232 
233  if (truthValue64 == truthValue32) {
234  if (truthValue64.has_value() && *truthValue64)
235  min = max;
236  else if (truthValue64.has_value() && !(*truthValue64))
237  max = min;
238  }
239  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // SizeOf, which is bounded between the two supported bitwidths (32 and 64).
244 //===----------------------------------------------------------------------===//
245 
246 void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
247  SetIntRangeFn setResultRange) {
248  unsigned storageWidth =
249  ConstantIntRanges::getStorageBitwidth(getResult().getType());
250  APInt min(/*numBits=*/storageWidth, indexMinWidth);
251  APInt max(/*numBits=*/storageWidth, indexMaxWidth);
252  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
253 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range, unsigned srcWidth, unsigned destWidth, bool isSigned)
static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range, Type sourceType, Type destType, bool isSigned)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create a ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const
Returns the union (computed separately for signed and unsigned bounds) of this range and other.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:56
ConstantIntRanges inferAnd(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferSub(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShrS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferIndexOp(InferRangeFn inferFn, ArrayRef< ConstantIntRanges > argRanges, CmpMode mode)
Compute inferFn on ranges, whose size should be the index storage bitwidth.
ConstantIntRanges extSIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the signed values in range to sign-extend it to destWidth.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
ConstantIntRanges inferMinS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMaxU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges)
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
ConstantIntRanges inferOr(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth)
Independently zero-extend the unsigned values and sign-extend the signed values in range to destWidth...
ConstantIntRanges truncRange(const ConstantIntRanges &range, unsigned destWidth)
Truncate range to destWidth bits, taking care to handle cases such as the truncation of [255,...
ConstantIntRanges inferDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferCeilDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMinU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferXor(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferDivS(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferShrU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferRemU(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMinWidth
ConstantIntRanges inferFloorDivS(ArrayRef< ConstantIntRanges > argRanges)
static constexpr unsigned indexMaxWidth
ConstantIntRanges inferCeilDivU(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges extUIRange(const ConstantIntRanges &range, unsigned destWidth)
Use the unsigned values in range to zero-extend it to destWidth.
ConstantIntRanges inferShl(ArrayRef< ConstantIntRanges > argRanges)
ConstantIntRanges inferMaxS(ArrayRef< ConstantIntRanges > argRanges)
Include the generated interface declarations.