MLIR  21.0.0git
InferIntRangeInterface.cpp
Go to the documentation of this file.
1 //===- InferIntRangeInterface.cpp - Integer range inference interface ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "mlir/IR/TypeUtilities.h"
12 #include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
13 #include <optional>
14 
15 using namespace mlir;
16 
18  return umin().getBitWidth() == other.umin().getBitWidth() &&
19  umin() == other.umin() && umax() == other.umax() &&
20  smin() == other.smin() && smax() == other.smax();
21 }
22 
23 const APInt &ConstantIntRanges::umin() const { return uminVal; }
24 
25 const APInt &ConstantIntRanges::umax() const { return umaxVal; }
26 
27 const APInt &ConstantIntRanges::smin() const { return sminVal; }
28 
29 const APInt &ConstantIntRanges::smax() const { return smaxVal; }
30 
32  type = getElementTypeOrSelf(type);
33  if (type.isIndex())
34  return IndexType::kInternalStorageBitWidth;
35  if (auto integerType = dyn_cast<IntegerType>(type))
36  return integerType.getWidth();
37  // Non-integer types have their bounds stored in width 0 `APInt`s.
38  return 0;
39 }
40 
42  return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth));
43 }
44 
46  return {value, value, value, value};
47 }
48 
50  bool isSigned) {
51  if (isSigned)
52  return fromSigned(min, max);
53  return fromUnsigned(min, max);
54 }
55 
57  const APInt &smax) {
58  unsigned int width = smin.getBitWidth();
59  APInt umin, umax;
60  if (smin.isNonNegative() == smax.isNonNegative()) {
61  umin = smin.ult(smax) ? smin : smax;
62  umax = smin.ugt(smax) ? smin : smax;
63  } else {
64  umin = APInt::getMinValue(width);
65  umax = APInt::getMaxValue(width);
66  }
67  return {umin, umax, smin, smax};
68 }
69 
71  const APInt &umax) {
72  unsigned int width = umin.getBitWidth();
73  APInt smin, smax;
74  if (umin.isNonNegative() == umax.isNonNegative()) {
75  smin = umin.slt(umax) ? umin : umax;
76  smax = umin.sgt(umax) ? umin : umax;
77  } else {
78  smin = APInt::getSignedMinValue(width);
79  smax = APInt::getSignedMaxValue(width);
80  }
81  return {umin, umax, smin, smax};
82 }
83 
86  // "Not an integer" poisons everything and also cannot be fed to comparison
87  // operators.
88  if (umin().getBitWidth() == 0)
89  return *this;
90  if (other.umin().getBitWidth() == 0)
91  return other;
92 
93  const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
94  const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
95  const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
96  const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
97 
98  return {uminUnion, umaxUnion, sminUnion, smaxUnion};
99 }
100 
103  // "Not an integer" poisons everything and also cannot be fed to comparison
104  // operators.
105  if (umin().getBitWidth() == 0)
106  return *this;
107  if (other.umin().getBitWidth() == 0)
108  return other;
109 
110  const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
111  const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
112  const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
113  const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
114 
115  return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
116 }
117 
118 std::optional<APInt> ConstantIntRanges::getConstantValue() const {
119  // Note: we need to exclude the trivially-equal width 0 values here.
120  if (umin() == umax() && umin().getBitWidth() != 0)
121  return umin();
122  if (smin() == smax() && smin().getBitWidth() != 0)
123  return smin();
124  return std::nullopt;
125 }
126 
127 raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
128  os << "unsigned : [";
129  range.umin().print(os, /*isSigned*/ false);
130  os << ", ";
131  range.umax().print(os, /*isSigned*/ false);
132  return os << "] signed : [" << range.smin() << ", " << range.smax() << "]";
133 }
134 
136  unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
137  if (width == 0)
138  return {};
139 
140  APInt umin = APInt::getMinValue(width);
141  APInt umax = APInt::getMaxValue(width);
142  APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
143  APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
144  return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}};
145 }
146 
147 raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) {
148  range.print(os);
149  return os;
150 }
151 
153  InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges,
154  SetIntLatticeFn setResultRanges) {
156  unpacked.reserve(argRanges.size());
157 
158  for (const IntegerValueRange &range : argRanges) {
159  if (range.isUninitialized())
160  return;
161  unpacked.push_back(range.getValue());
162  }
163 
164  interface.inferResultRanges(
165  unpacked,
166  [&setResultRanges](Value value, const ConstantIntRanges &argRanges) {
167  setResultRanges(value, IntegerValueRange{argRanges});
168  });
169 }
170 
172  InferIntRangeInterface interface, ArrayRef<ConstantIntRanges> argRanges,
173  SetIntRangeFn setResultRanges) {
174  auto ranges = llvm::to_vector_of<IntegerValueRange>(argRanges);
175  interface.inferResultRangesFromOptional(
176  ranges,
177  [&setResultRanges](Value value, const IntegerValueRange &argRanges) {
178  if (!argRanges.isUninitialized())
179  setResultRanges(value, argRanges.getValue());
180  });
181 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
Definition: SPIRVToLLVM.cpp:67
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges maxRange(unsigned bitwidth)
Create a ConstantIntRanges with the maximum bounds for the width bitwidth, that is - [0,...
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create a ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
static ConstantIntRanges range(const APInt &min, const APInt &max, bool isSigned)
Create a ConstantIntRanges whose minimum is min and maximum is max with isSigned specifying if the mi...
ConstantIntRanges intersection(const ConstantIntRanges &other) const
Returns the intersection (computed separately for signed and unsigned bounds) of this range and other...
static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax)
Create a ConstantIntRanges with the signed minimum and maximum equal to smin and smax,...
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretations of the range indicate that the value it bounds is a ...
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const
Returns the union (computed separately for signed and unsigned bounds) of this range and other.
bool operator==(const ConstantIntRanges &other) const
This lattice value represents the integer range of an SSA value.
void print(raw_ostream &os) const
Print the integer value range.
static IntegerValueRange getMaxRange(Value value)
Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) range that is used to mark the v...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void defaultInferResultRanges(InferIntRangeInterface interface, ArrayRef< IntegerValueRange > argRanges, SetIntLatticeFn setResultRanges)
Default implementation of inferResultRangesFromOptional, which unwraps the argument ranges and dispatches to inferResultRanges.
void defaultInferResultRangesFromOptional(InferIntRangeInterface interface, ArrayRef< ConstantIntRanges > argRanges, SetIntRangeFn setResultRanges)
Default implementation of inferResultRanges, which wraps the argument ranges as IntegerValueRanges and dispatches to inferResultRangesFromOptional.
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78