MLIR  21.0.0git
StaticValueUtils.cpp
Go to the documentation of this file.
1 //===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Matchers.h"
11 #include "mlir/Support/LLVM.h"
12 #include "llvm/ADT/APSInt.h"
13 #include "llvm/ADT/STLExtras.h"
14 #include "llvm/Support/MathExtras.h"
15 
16 namespace mlir {
17 
19  if (!v)
20  return false;
21  std::optional<int64_t> constint = getConstantIntValue(v);
22  if (!constint)
23  return false;
24  return *constint == 0;
25 }
26 
27 std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
28  SmallVector<OpFoldResult>>
30  SmallVector<OpFoldResult> offsets, sizes, strides;
31  offsets.reserve(ranges.size());
32  sizes.reserve(ranges.size());
33  strides.reserve(ranges.size());
34  for (const auto &[offset, size, stride] : ranges) {
35  offsets.push_back(offset);
36  sizes.push_back(size);
37  strides.push_back(stride);
38  }
39  return std::make_tuple(offsets, sizes, strides);
40 }
41 
42 /// Helper function to dispatch an OpFoldResult into `staticVec` if:
43 /// a) it is an IntegerAttr (its sign-extended value is pushed).
44 /// In other cases, the OpFoldResult is dispatched to the `dynamicVec`.
45 /// In such dynamic cases, `ShapedType::kDynamic` is also pushed to
46 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
47 /// come from an AttrSizedOperandSegments trait.
49  SmallVectorImpl<Value> &dynamicVec,
50  SmallVectorImpl<int64_t> &staticVec) {
51  auto v = llvm::dyn_cast_if_present<Value>(ofr);
52  if (!v) {
53  APInt apInt = cast<IntegerAttr>(cast<Attribute>(ofr)).getValue();
54  staticVec.push_back(apInt.getSExtValue());
55  return;
56  }
57  dynamicVec.push_back(v);
58  staticVec.push_back(ShapedType::kDynamic);
59 }
60 
61 std::pair<int64_t, OpFoldResult>
63  int64_t tileSizeForShape =
64  getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
65 
66  OpFoldResult tileSizeOfrSimplified =
67  (tileSizeForShape != ShapedType::kDynamic)
68  ? b.getIndexAttr(tileSizeForShape)
69  : tileSizeOfr;
70 
71  return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
72  tileSizeOfrSimplified);
73 }
74 
76  SmallVectorImpl<Value> &dynamicVec,
77  SmallVectorImpl<int64_t> &staticVec) {
78  for (OpFoldResult ofr : ofrs)
79  dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec);
80 }
81 
82 /// Given a value, try to extract a constant Attribute. If this fails, return
83 /// the original value.
85  if (!val)
86  return OpFoldResult();
87  Attribute attr;
88  if (matchPattern(val, m_Constant(&attr)))
89  return attr;
90  return val;
91 }
92 
93 /// Given an array of values, try to extract a constant Attribute from each
94 /// value. If this fails, return the original value.
96  return llvm::to_vector(
97  llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
98 }
99 
100 /// Convert `arrayAttr` to a vector of OpFoldResult.
103  res.reserve(arrayAttr.size());
104  for (Attribute a : arrayAttr)
105  res.push_back(a);
106  return res;
107 }
108 
110  return IntegerAttr::get(IndexType::get(ctx), val);
111 }
112 
114  ArrayRef<int64_t> values) {
115  return llvm::to_vector(llvm::map_range(
116  values, [ctx](int64_t v) { return getAsIndexOpFoldResult(ctx, v); }));
117 }
118 
119 /// If ofr is a constant integer or an IntegerAttr, return the integer.
120 std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
121  // Case 1: Check for Constant integer.
122  if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
123  APSInt intVal;
124  if (matchPattern(val, m_ConstantInt(&intVal)))
125  return intVal.getSExtValue();
126  return std::nullopt;
127  }
128  // Case 2: Check for IntegerAttr.
129  Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
130  if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
131  return intAttr.getValue().getSExtValue();
132  return std::nullopt;
133 }
134 
135 std::optional<SmallVector<int64_t>>
137  bool failed = false;
138  SmallVector<int64_t> res = llvm::map_to_vector(ofrs, [&](OpFoldResult ofr) {
139  auto cv = getConstantIntValue(ofr);
140  if (!cv.has_value())
141  failed = true;
142  return cv.value_or(0);
143  });
144  if (failed)
145  return std::nullopt;
146  return res;
147 }
148 
149 bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
150  auto val = getConstantIntValue(ofr);
151  return val && *val == value;
152 }
153 
155  return llvm::all_of(
156  ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
157 }
158 
160  ArrayRef<int64_t> values) {
161  if (ofrs.size() != values.size())
162  return false;
163  std::optional<SmallVector<int64_t>> constOfrs = getConstantIntValues(ofrs);
164  return constOfrs && llvm::equal(constOfrs.value(), values);
165 }
166 
167 /// Return true if ofr1 and ofr2 are the same integer constant attribute values
168 /// or the same SSA value.
169 /// Ignore integer bitwidth and type mismatch that come from the fact there is
170 /// no IndexAttr and that IndexType has no bitwidth.
172  auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
173  if (cst1 && cst2 && *cst1 == *cst2)
174  return true;
175  auto v1 = llvm::dyn_cast_if_present<Value>(ofr1),
176  v2 = llvm::dyn_cast_if_present<Value>(ofr2);
177  return v1 && v1 == v2;
178 }
179 
181  ArrayRef<OpFoldResult> ofrs2) {
182  if (ofrs1.size() != ofrs2.size())
183  return false;
184  for (auto [ofr1, ofr2] : llvm::zip_equal(ofrs1, ofrs2))
185  if (!isEqualConstantIntOrValue(ofr1, ofr2))
186  return false;
187  return true;
188 }
189 
190 /// Return a vector of OpFoldResults with the same size as staticValues, in
191 /// which every element for which ShapedType::isDynamic is true is replaced by
192 /// the next value from dynamicValues.
194  ValueRange dynamicValues,
195  MLIRContext *context) {
197  res.reserve(staticValues.size());
198  unsigned numDynamic = 0;
199  unsigned count = static_cast<unsigned>(staticValues.size());
200  for (unsigned idx = 0; idx < count; ++idx) {
201  int64_t value = staticValues[idx];
202  res.push_back(ShapedType::isDynamic(value)
203  ? OpFoldResult{dynamicValues[numDynamic++]}
205  IntegerType::get(context, 64), staticValues[idx])});
206  }
207  return res;
208 }
210  ValueRange dynamicValues, Builder &b) {
211  return getMixedValues(staticValues, dynamicValues, b.getContext());
212 }
213 
214 /// Decompose a vector of mixed static or dynamic values into the corresponding
215 /// pair of arrays. This is the inverse function of `getMixedValues`.
216 std::pair<SmallVector<int64_t>, SmallVector<Value>>
218  SmallVector<int64_t> staticValues;
219  SmallVector<Value> dynamicValues;
220  for (const auto &it : mixedValues) {
221  if (auto attr = dyn_cast<Attribute>(it)) {
222  staticValues.push_back(cast<IntegerAttr>(attr).getInt());
223  } else {
224  staticValues.push_back(ShapedType::kDynamic);
225  dynamicValues.push_back(cast<Value>(it));
226  }
227  }
228  return {staticValues, dynamicValues};
229 }
230 
231 /// Helper to sort `values` according to matching `keys`.
232 template <typename K, typename V>
233 static SmallVector<V>
235  llvm::function_ref<bool(K, K)> compare) {
236  if (keys.empty())
237  return SmallVector<V>{values};
238  assert(keys.size() == values.size() && "unexpected mismatching sizes");
239  auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
240  std::sort(indices.begin(), indices.end(),
241  [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
242  SmallVector<V> res;
243  res.reserve(values.size());
244  for (int64_t i = 0, e = indices.size(); i < e; ++i)
245  res.push_back(values[indices[i]]);
246  return res;
247 }
248 
249 SmallVector<Value>
252  return getValuesSortedByKeyImpl(keys, values, compare);
253 }
254 
255 SmallVector<OpFoldResult>
258  return getValuesSortedByKeyImpl(keys, values, compare);
259 }
260 
261 SmallVector<int64_t>
264  return getValuesSortedByKeyImpl(keys, values, compare);
265 }
266 
267 /// Return the number of iterations for a loop with a lower bound `lb`, upper
268 /// bound `ub` and step `step`.
269 std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
270  OpFoldResult step) {
271  if (lb == ub)
272  return 0;
273 
274  std::optional<int64_t> lbConstant = getConstantIntValue(lb);
275  if (!lbConstant)
276  return std::nullopt;
277  std::optional<int64_t> ubConstant = getConstantIntValue(ub);
278  if (!ubConstant)
279  return std::nullopt;
280  std::optional<int64_t> stepConstant = getConstantIntValue(step);
281  if (!stepConstant)
282  return std::nullopt;
283 
284  return llvm::divideCeilSigned(*ubConstant - *lbConstant, *stepConstant);
285 }
286 
288  return llvm::none_of(sizesOrOffsets, [](int64_t value) {
289  return !ShapedType::isDynamic(value) && value < 0;
290  });
291 }
292 
294  return llvm::none_of(strides, [](int64_t value) {
295  return !ShapedType::isDynamic(value) && value == 0;
296  });
297 }
298 
300  bool onlyNonNegative, bool onlyNonZero) {
301  bool valuesChanged = false;
302  for (OpFoldResult &ofr : ofrs) {
303  if (isa<Attribute>(ofr))
304  continue;
305  Attribute attr;
306  if (matchPattern(cast<Value>(ofr), m_Constant(&attr))) {
307  // Note: All ofrs have index type.
308  if (onlyNonNegative && *getConstantIntValue(attr) < 0)
309  continue;
310  if (onlyNonZero && *getConstantIntValue(attr) == 0)
311  continue;
312  ofr = attr;
313  valuesChanged = true;
314  }
315  }
316  return success(valuesChanged);
317 }
318 
319 LogicalResult
321  return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true,
322  /*onlyNonZero=*/false);
323 }
324 
326  return foldDynamicIndexList(strides, /*onlyNonNegative=*/false,
327  /*onlyNonZero=*/true);
328 }
329 
330 } // namespace mlir
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
MLIRContext * getContext() const
Definition: Builders.h:56
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:68
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
bool isZeroIndex(OpFoldResult v)
Return true if v is an IntegerAttr with value 0 or a ConstantIndexOp with attribute with value 0.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
bool areConstantIntValues(ArrayRef< OpFoldResult > ofrs, ArrayRef< int64_t > values)
Return true if all of ofrs are constant integers equal to the corresponding value in values.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
static SmallVector< V > getValuesSortedByKeyImpl(ArrayRef< K > keys, ArrayRef< V > values, llvm::function_ref< bool(K, K)> compare)
Helper to sort values according to matching keys.
bool isEqualConstantIntOrValueArray(ArrayRef< OpFoldResult > ofrs1, ArrayRef< OpFoldResult > ofrs2)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void dispatchIndexOpFoldResult(OpFoldResult ofr, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch an OpFoldResult into staticVec if: a) it is an IntegerAttr In other cases...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)
If all ofrs are constant integers or IntegerAttrs, return the integers.
SmallVector< Value > getValuesSortedByKey(ArrayRef< Attribute > keys, ArrayRef< Value > values, llvm::function_ref< bool(Attribute, Attribute)> compare)
Helper to sort values according to matching keys.
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.