MLIR 22.0.0git
InferIntRangeInterface.cpp
Go to the documentation of this file.
1//===- InferIntRangeInterface.cpp - Integer range inference interface ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/Interfaces/InferIntRangeInterface.cpp.inc"
13#include <optional>
14
15using namespace mlir;
16
18 return umin().getBitWidth() == other.umin().getBitWidth() &&
19 umin() == other.umin() && umax() == other.umax() &&
20 smin() == other.smin() && smax() == other.smax();
21}
22
23const APInt &ConstantIntRanges::umin() const { return uminVal; }
24
25const APInt &ConstantIntRanges::umax() const { return umaxVal; }
26
27const APInt &ConstantIntRanges::smin() const { return sminVal; }
28
29const APInt &ConstantIntRanges::smax() const { return smaxVal; }
30
32 type = getElementTypeOrSelf(type);
33 if (type.isIndex())
34 return IndexType::kInternalStorageBitWidth;
35 if (auto integerType = dyn_cast<IntegerType>(type))
36 return integerType.getWidth();
37 // Non-integer types have their bounds stored in width 0 `APInt`s.
38 return 0;
39}
40
42 return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth));
43}
44
46 return {value, value, value, value};
47}
48
50 bool isSigned) {
51 if (isSigned)
52 return fromSigned(min, max);
53 return fromUnsigned(min, max);
54}
55
57 const APInt &smax) {
58 unsigned int width = smin.getBitWidth();
59 APInt umin, umax;
60 if (smin.isNonNegative() == smax.isNonNegative()) {
61 umin = smin.ult(smax) ? smin : smax;
62 umax = smin.ugt(smax) ? smin : smax;
63 } else {
64 umin = APInt::getMinValue(width);
65 umax = APInt::getMaxValue(width);
66 }
67 return {umin, umax, smin, smax};
68}
69
71 const APInt &umax) {
72 unsigned int width = umin.getBitWidth();
73 APInt smin, smax;
74 if (umin.isNonNegative() == umax.isNonNegative()) {
75 smin = umin.slt(umax) ? umin : umax;
76 smax = umin.sgt(umax) ? umin : umax;
77 } else {
78 smin = APInt::getSignedMinValue(width);
79 smax = APInt::getSignedMaxValue(width);
80 }
81 return {umin, umax, smin, smax};
82}
83
86 // "Not an integer" poisons everything and also cannot be fed to comparison
87 // operators.
88 if (umin().getBitWidth() == 0)
89 return *this;
90 if (other.umin().getBitWidth() == 0)
91 return other;
92
93 const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
94 const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
95 const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
96 const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
97
98 return {uminUnion, umaxUnion, sminUnion, smaxUnion};
99}
100
103 // "Not an integer" poisons everything and also cannot be fed to comparison
104 // operators.
105 if (umin().getBitWidth() == 0)
106 return *this;
107 if (other.umin().getBitWidth() == 0)
108 return other;
109
110 const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
111 const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
112 const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
113 const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
114
115 return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
116}
117
118std::optional<APInt> ConstantIntRanges::getConstantValue() const {
119 // Note: we need to exclude the trivially-equal width 0 values here.
120 if (umin() == umax() && umin().getBitWidth() != 0)
121 return umin();
122 if (smin() == smax() && smin().getBitWidth() != 0)
123 return smin();
124 return std::nullopt;
125}
126
128 os << "unsigned : [";
129 range.umin().print(os, /*isSigned*/ false);
130 os << ", ";
131 range.umax().print(os, /*isSigned*/ false);
132 return os << "] signed : [" << range.smin() << ", " << range.smax() << "]";
133}
134
136 unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
137 APInt umin = APInt::getMinValue(width);
138 APInt umax = APInt::getMaxValue(width);
139 APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
140 APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
141 return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}};
142}
143
145 range.print(os);
146 return os;
147}
148
151 GetIntRangeFn getIntRange, int32_t indexBitwidth) {
153 ranges.reserve(values.size());
154 for (OpFoldResult ofr : values) {
155 if (auto value = dyn_cast<Value>(ofr)) {
156 ranges.push_back(getIntRange(value));
157 continue;
158 }
159
160 // Create a constant range.
161 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
162 ranges.emplace_back(ConstantIntRanges::constant(
163 attr.getValue().sextOrTrunc(indexBitwidth)));
164 }
165 return ranges;
166}
167
169 InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges,
170 SetIntLatticeFn setResultRanges) {
172 unpacked.reserve(argRanges.size());
173
174 for (const IntegerValueRange &range : argRanges) {
175 if (range.isUninitialized())
176 return;
177 unpacked.push_back(range.getValue());
178 }
179
180 interface.inferResultRanges(
181 unpacked,
182 [&setResultRanges](Value value, const ConstantIntRanges &argRanges) {
183 setResultRanges(value, IntegerValueRange{argRanges});
184 });
185}
186
188 InferIntRangeInterface interface, ArrayRef<ConstantIntRanges> argRanges,
189 SetIntRangeFn setResultRanges) {
190 auto ranges = llvm::to_vector_of<IntegerValueRange>(argRanges);
191 interface.inferResultRangesFromOptional(
192 ranges,
193 [&setResultRanges](Value value, const IntegerValueRange &argRanges) {
194 if (!argRanges.isUninitialized())
195 setResultRanges(value, argRanges.getValue());
196 });
197}
static unsigned getBitWidth(Type type)
Definition Pattern.cpp:385
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges maxRange(unsigned bitwidth)
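Create a ConstantIntRanges with the maximum bounds for the width bitwidth, that is - [0, uint_max(bitwidth)] when interpreted as unsigned and [int_min(bitwidth), int_max(bitwidth)] when interpreted as signed.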
ConstantIntRanges(const APInt &umin, const APInt &umax, const APInt &smin, const APInt &smax)
Bound umin <= (unsigned)x <= umax and smin <= signed(x) <= smax.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value, value] for both its signed and unsigned interpretations.
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
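Create a ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax, and the signed part equal to the signed minimum and maximum of those values if both have the same sign bit, or to the full signed range otherwise.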
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
static ConstantIntRanges range(const APInt &min, const APInt &max, bool isSigned)
Create a ConstantIntRanges whose minimum is min and maximum is max, with isSigned specifying whether min and max are interpreted as signed or unsigned.
ConstantIntRanges intersection(const ConstantIntRanges &other) const
Returns the intersection (computed separately for signed and unsigned bounds) of this range and other.
static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax)
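Create a ConstantIntRanges with the signed minimum and maximum equal to smin and smax, and the unsigned part equal to the unsigned minimum and maximum of those values if both have the same sign, or to the full unsigned range otherwise.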
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
std::optional< APInt > getConstantValue() const
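If either the signed or unsigned interpretation of the range indicates that the value it bounds is a constant, return that constant; otherwise (including for width-0, non-integer ranges) return std::nullopt.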
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const
Returns the union (computed separately for signed and unsigned bounds) of this range and other.
bool operator==(const ConstantIntRanges &other) const
This lattice value represents the integer range of an SSA value.
IntegerValueRange(ConstantIntRanges value)
Create an integer value range lattice value.
void print(raw_ostream &os) const
Print the integer value range.
static IntegerValueRange getMaxRange(Value value)
Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) for the type t of value, using its storage bitwidth; non-integer types get width-0 bounds.
This class represents a single result from folding an operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:54
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void defaultInferResultRanges(InferIntRangeInterface interface, ArrayRef< IntegerValueRange > argRanges, SetIntLatticeFn setResultRanges)
Default implementation of inferResultRanges which dispatches to the inferResultRangesFromOptional.
void defaultInferResultRangesFromOptional(InferIntRangeInterface interface, ArrayRef< ConstantIntRanges > argRanges, SetIntRangeFn setResultRanges)
Default implementation of inferResultRangesFromOptional which dispatches to the inferResultRanges.
Include the generated interface declarations.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
SmallVector< IntegerValueRange > getIntValueRanges(ArrayRef< OpFoldResult > values, GetIntRangeFn getIntRange, int32_t indexBitwidth)
Helper function to collect the integer range values of an array of op fold results.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
function_ref< IntegerValueRange(Value)> GetIntRangeFn
Helper callback type to get the integer range of a value.