MLIR  19.0.0git
StaticValueUtils.cpp
Go to the documentation of this file.
1 //===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Matchers.h"
11 #include "mlir/Support/LLVM.h"
12 #include "llvm/ADT/APSInt.h"
13 #include "llvm/Support/MathExtras.h"
14 
15 namespace mlir {
16 
18  if (!v)
19  return false;
20  std::optional<int64_t> constint = getConstantIntValue(v);
21  if (!constint)
22  return false;
23  return *constint == 0;
24 }
25 
26 std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
27  SmallVector<OpFoldResult>>
29  SmallVector<OpFoldResult> offsets, sizes, strides;
30  offsets.reserve(ranges.size());
31  sizes.reserve(ranges.size());
32  strides.reserve(ranges.size());
33  for (const auto &[offset, size, stride] : ranges) {
34  offsets.push_back(offset);
35  sizes.push_back(size);
36  strides.push_back(stride);
37  }
38  return std::make_tuple(offsets, sizes, strides);
39 }
40 
41 /// Helper function to dispatch an OpFoldResult into `staticVec` if:
42 /// a) it is an IntegerAttr
43 /// In other cases, the OpFoldResult is dispatched to the `dynamicVec`.
44 /// In such dynamic cases, a copy of the `sentinel` value is also pushed to
45 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
46 /// come from an AttrSizedOperandSegments trait.
48  SmallVectorImpl<Value> &dynamicVec,
49  SmallVectorImpl<int64_t> &staticVec) {
50  auto v = llvm::dyn_cast_if_present<Value>(ofr);
51  if (!v) {
52  APInt apInt = cast<IntegerAttr>(ofr.get<Attribute>()).getValue();
53  staticVec.push_back(apInt.getSExtValue());
54  return;
55  }
56  dynamicVec.push_back(v);
57  staticVec.push_back(ShapedType::kDynamic);
58 }
59 
61  SmallVectorImpl<Value> &dynamicVec,
62  SmallVectorImpl<int64_t> &staticVec) {
63  for (OpFoldResult ofr : ofrs)
64  dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec);
65 }
66 
67 /// Given a value, try to extract a constant Attribute. If this fails, return
68 /// the original value.
70  if (!val)
71  return OpFoldResult();
72  Attribute attr;
73  if (matchPattern(val, m_Constant(&attr)))
74  return attr;
75  return val;
76 }
77 
78 /// Given an array of values, try to extract a constant Attribute from each
79 /// value. If this fails, return the original value.
81  return llvm::to_vector(
82  llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
83 }
84 
85 /// Convert `arrayAttr` to a vector of OpFoldResult.
88  res.reserve(arrayAttr.size());
89  for (Attribute a : arrayAttr)
90  res.push_back(a);
91  return res;
92 }
93 
95  return IntegerAttr::get(IndexType::get(ctx), val);
96 }
97 
99  ArrayRef<int64_t> values) {
100  return llvm::to_vector(llvm::map_range(
101  values, [ctx](int64_t v) { return getAsIndexOpFoldResult(ctx, v); }));
102 }
103 
104 /// If ofr is a constant integer or an IntegerAttr, return the integer.
105 std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
106  // Case 1: Check for Constant integer.
107  if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
108  APSInt intVal;
109  if (matchPattern(val, m_ConstantInt(&intVal)))
110  return intVal.getSExtValue();
111  return std::nullopt;
112  }
113  // Case 2: Check for IntegerAttr.
114  Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
115  if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
116  return intAttr.getValue().getSExtValue();
117  return std::nullopt;
118 }
119 
120 std::optional<SmallVector<int64_t>>
122  bool failed = false;
123  SmallVector<int64_t> res = llvm::map_to_vector(ofrs, [&](OpFoldResult ofr) {
124  auto cv = getConstantIntValue(ofr);
125  if (!cv.has_value())
126  failed = true;
127  return cv.has_value() ? cv.value() : 0;
128  });
129  if (failed)
130  return std::nullopt;
131  return res;
132 }
133 
134 /// Return true if `ofr` is constant integer equal to `value`.
135 bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
136  auto val = getConstantIntValue(ofr);
137  return val && *val == value;
138 }
139 
140 /// Return true if ofr1 and ofr2 are the same integer constant attribute values
141 /// or the same SSA value.
142 /// Ignore integer bitwidth and type mismatch that come from the fact there is
143 /// no IndexAttr and that IndexType has no bitwidth.
145  auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
146  if (cst1 && cst2 && *cst1 == *cst2)
147  return true;
148  auto v1 = llvm::dyn_cast_if_present<Value>(ofr1),
149  v2 = llvm::dyn_cast_if_present<Value>(ofr2);
150  return v1 && v1 == v2;
151 }
152 
154  ArrayRef<OpFoldResult> ofrs2) {
155  if (ofrs1.size() != ofrs2.size())
156  return false;
157  for (auto [ofr1, ofr2] : llvm::zip_equal(ofrs1, ofrs2))
158  if (!isEqualConstantIntOrValue(ofr1, ofr2))
159  return false;
160  return true;
161 }
162 
163 /// Return a vector of OpFoldResults with the same size as staticValues, but all
164 /// elements for which ShapedType::isDynamic is true, will be replaced by
165 /// dynamicValues.
167  ValueRange dynamicValues, Builder &b) {
169  res.reserve(staticValues.size());
170  unsigned numDynamic = 0;
171  unsigned count = static_cast<unsigned>(staticValues.size());
172  for (unsigned idx = 0; idx < count; ++idx) {
173  int64_t value = staticValues[idx];
174  res.push_back(ShapedType::isDynamic(value)
175  ? OpFoldResult{dynamicValues[numDynamic++]}
176  : OpFoldResult{b.getI64IntegerAttr(staticValues[idx])});
177  }
178  return res;
179 }
180 
181 /// Decompose a vector of mixed static or dynamic values into the corresponding
182 /// pair of arrays. This is the inverse function of `getMixedValues`.
183 std::pair<SmallVector<int64_t>, SmallVector<Value>>
185  SmallVector<int64_t> staticValues;
186  SmallVector<Value> dynamicValues;
187  for (const auto &it : mixedValues) {
188  if (it.is<Attribute>()) {
189  staticValues.push_back(cast<IntegerAttr>(it.get<Attribute>()).getInt());
190  } else {
191  staticValues.push_back(ShapedType::kDynamic);
192  dynamicValues.push_back(it.get<Value>());
193  }
194  }
195  return {staticValues, dynamicValues};
196 }
197 
198 /// Helper to sort `values` according to matching `keys`.
199 template <typename K, typename V>
200 static SmallVector<V>
202  llvm::function_ref<bool(K, K)> compare) {
203  if (keys.empty())
204  return SmallVector<V>{values};
205  assert(keys.size() == values.size() && "unexpected mismatching sizes");
206  auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
207  std::sort(indices.begin(), indices.end(),
208  [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
209  SmallVector<V> res;
210  res.reserve(values.size());
211  for (int64_t i = 0, e = indices.size(); i < e; ++i)
212  res.push_back(values[indices[i]]);
213  return res;
214 }
215 
216 SmallVector<Value>
219  return getValuesSortedByKeyImpl(keys, values, compare);
220 }
221 
222 SmallVector<OpFoldResult>
225  return getValuesSortedByKeyImpl(keys, values, compare);
226 }
227 
228 SmallVector<int64_t>
231  return getValuesSortedByKeyImpl(keys, values, compare);
232 }
233 
234 /// Return the number of iterations for a loop with a lower bound `lb`, upper
235 /// bound `ub` and step `step`.
236 std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
237  OpFoldResult step) {
238  if (lb == ub)
239  return 0;
240 
241  std::optional<int64_t> lbConstant = getConstantIntValue(lb);
242  if (!lbConstant)
243  return std::nullopt;
244  std::optional<int64_t> ubConstant = getConstantIntValue(ub);
245  if (!ubConstant)
246  return std::nullopt;
247  std::optional<int64_t> stepConstant = getConstantIntValue(step);
248  if (!stepConstant)
249  return std::nullopt;
250 
251  return llvm::divideCeilSigned(*ubConstant - *lbConstant, *stepConstant);
252 }
253 
255  return llvm::none_of(sizesOrOffsets, [](int64_t value) {
256  return !ShapedType::isDynamic(value) && value < 0;
257  });
258 }
259 
261  return llvm::none_of(strides, [](int64_t value) {
262  return !ShapedType::isDynamic(value) && value == 0;
263  });
264 }
265 
267  bool onlyNonNegative, bool onlyNonZero) {
268  bool valuesChanged = false;
269  for (OpFoldResult &ofr : ofrs) {
270  if (ofr.is<Attribute>())
271  continue;
272  Attribute attr;
273  if (matchPattern(ofr.get<Value>(), m_Constant(&attr))) {
274  // Note: All ofrs have index type.
275  if (onlyNonNegative && *getConstantIntValue(attr) < 0)
276  continue;
277  if (onlyNonZero && *getConstantIntValue(attr) == 0)
278  continue;
279  ofr = attr;
280  valuesChanged = true;
281  }
282  }
283  return success(valuesChanged);
284 }
285 
286 LogicalResult
288  return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true,
289  /*onlyNonZero=*/false);
290 }
291 
293  return foldDynamicIndexList(strides, /*onlyNonNegative=*/false,
294  /*onlyNonZero=*/true);
295 }
296 
297 } // namespace mlir
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:128
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:67
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
bool isZeroIndex(OpFoldResult v)
Return true if v is an IntegerAttr with value 0 or a ConstantIndexOp with attribute with value 0.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
static SmallVector< V > getValuesSortedByKeyImpl(ArrayRef< K > keys, ArrayRef< V > values, llvm::function_ref< bool(K, K)> compare)
Helper to sort values according to matching keys.
bool isEqualConstantIntOrValueArray(ArrayRef< OpFoldResult > ofrs1, ArrayRef< OpFoldResult > ofrs2)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
void dispatchIndexOpFoldResult(OpFoldResult ofr, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch an OpFoldResult into staticVec if: a) it is an IntegerAttr In other cases...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)
If all ofrs are constant integers or IntegerAttrs, return the integers.
SmallVector< Value > getValuesSortedByKey(ArrayRef< Attribute > keys, ArrayRef< Value > values, llvm::function_ref< bool(Attribute, Attribute)> compare)
Helper to sort values according to matching keys.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.