MLIR  22.0.0git
StaticValueUtils.cpp
Go to the documentation of this file.
1 //===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Attributes.h"
11 #include "mlir/IR/Matchers.h"
12 #include "mlir/Support/LLVM.h"
13 #include "llvm/ADT/APSInt.h"
14 #include "llvm/ADT/STLExtras.h"
15 #include "llvm/Support/DebugLog.h"
16 #include "llvm/Support/MathExtras.h"
17 
18 namespace mlir {
19 
21 
22 bool isOneInteger(OpFoldResult v) { return isConstantIntValue(v, 1); }
23 
24 std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
25  SmallVector<OpFoldResult>>
27  SmallVector<OpFoldResult> offsets, sizes, strides;
28  offsets.reserve(ranges.size());
29  sizes.reserve(ranges.size());
30  strides.reserve(ranges.size());
31  for (const auto &[offset, size, stride] : ranges) {
32  offsets.push_back(offset);
33  sizes.push_back(size);
34  strides.push_back(stride);
35  }
36  return std::make_tuple(offsets, sizes, strides);
37 }
38 
39 /// Helper function to dispatch an OpFoldResult into `staticVec` if it is an
40 /// IntegerAttr (its value is pushed sign-extended to int64_t).
41 /// In all other cases, the OpFoldResult is dispatched to `dynamicVec`.
42 /// In such dynamic cases, ShapedType::kDynamic is also pushed to
43 /// `staticVec` as a placeholder. This is useful to extract mixed static and
44 /// dynamic entries that come from an AttrSizedOperandSegments trait.
46  SmallVectorImpl<Value> &dynamicVec,
47  SmallVectorImpl<int64_t> &staticVec) {
48  auto v = llvm::dyn_cast_if_present<Value>(ofr);
49  if (!v) {
50  APInt apInt = cast<IntegerAttr>(cast<Attribute>(ofr)).getValue();
51  staticVec.push_back(apInt.getSExtValue());
52  return;
53  }
54  dynamicVec.push_back(v);
55  staticVec.push_back(ShapedType::kDynamic);
56 }
57 
58 std::pair<int64_t, OpFoldResult>
60  int64_t tileSizeForShape =
61  getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
62 
63  OpFoldResult tileSizeOfrSimplified =
64  (tileSizeForShape != ShapedType::kDynamic)
65  ? b.getIndexAttr(tileSizeForShape)
66  : tileSizeOfr;
67 
68  return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
69  tileSizeOfrSimplified);
70 }
71 
73  SmallVectorImpl<Value> &dynamicVec,
74  SmallVectorImpl<int64_t> &staticVec) {
75  for (OpFoldResult ofr : ofrs)
76  dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec);
77 }
78 
79 /// Given a value, try to extract a constant Attribute. If this fails, return
80 /// the original value.
82  if (!val)
83  return OpFoldResult();
84  Attribute attr;
85  if (matchPattern(val, m_Constant(&attr)))
86  return attr;
87  return val;
88 }
89 
90 /// Given an array of values, try to extract a constant Attribute from each
91 /// value. If this fails, return the original value.
93  return llvm::to_vector(
94  llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
95 }
96 
97 /// Convert `arrayAttr` to a vector of OpFoldResult.
100  res.reserve(arrayAttr.size());
101  for (Attribute a : arrayAttr)
102  res.push_back(a);
103  return res;
104 }
105 
107  return IntegerAttr::get(IndexType::get(ctx), val);
108 }
109 
111  ArrayRef<int64_t> values) {
112  return llvm::to_vector(llvm::map_range(
113  values, [ctx](int64_t v) { return getAsIndexOpFoldResult(ctx, v); }));
114 }
115 
116 /// If ofr is a constant integer or an IntegerAttr, return the integer.
117 /// The boolean indicates whether the value is an index type.
118 std::optional<std::pair<APInt, bool>> getConstantAPIntValue(OpFoldResult ofr) {
119  // Case 1: Check for Constant integer.
120  if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
121  APInt intVal;
122  if (matchPattern(val, m_ConstantInt(&intVal)))
123  return std::make_pair(intVal, val.getType().isIndex());
124  return std::nullopt;
125  }
126  // Case 2: Check for IntegerAttr.
127  Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
128  if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
129  return std::make_pair(intAttr.getValue(), intAttr.getType().isIndex());
130  return std::nullopt;
131 }
132 
133 /// If ofr is a constant integer or an IntegerAttr, return the integer.
134 std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
135  std::optional<std::pair<APInt, bool>> apInt = getConstantAPIntValue(ofr);
136  if (!apInt)
137  return std::nullopt;
138  return apInt->first.getSExtValue();
139 }
140 
141 std::optional<SmallVector<int64_t>>
143  bool failed = false;
144  SmallVector<int64_t> res = llvm::map_to_vector(ofrs, [&](OpFoldResult ofr) {
145  auto cv = getConstantIntValue(ofr);
146  if (!cv.has_value())
147  failed = true;
148  return cv.value_or(0);
149  });
150  if (failed)
151  return std::nullopt;
152  return res;
153 }
154 
155 bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
156  return getConstantIntValue(ofr) == value;
157 }
158 
160  return llvm::all_of(
161  ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
162 }
163 
165  ArrayRef<int64_t> values) {
166  if (ofrs.size() != values.size())
167  return false;
168  std::optional<SmallVector<int64_t>> constOfrs = getConstantIntValues(ofrs);
169  return constOfrs && llvm::equal(constOfrs.value(), values);
170 }
171 
172 /// Return true if ofr1 and ofr2 are the same integer constant attribute values
173 /// or the same SSA value.
174 /// Ignore integer bitwidth and type mismatch that come from the fact there is
175 /// no IndexAttr and that IndexType has no bitwidth.
177  auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
178  if (cst1 && cst2 && *cst1 == *cst2)
179  return true;
180  auto v1 = llvm::dyn_cast_if_present<Value>(ofr1),
181  v2 = llvm::dyn_cast_if_present<Value>(ofr2);
182  return v1 && v1 == v2;
183 }
184 
186  ArrayRef<OpFoldResult> ofrs2) {
187  if (ofrs1.size() != ofrs2.size())
188  return false;
189  for (auto [ofr1, ofr2] : llvm::zip_equal(ofrs1, ofrs2))
190  if (!isEqualConstantIntOrValue(ofr1, ofr2))
191  return false;
192  return true;
193 }
194 
195 /// Return a vector of OpFoldResults with the same size as staticValues, in
196 /// which every element for which ShapedType::isDynamic is true is replaced,
197 /// in order, by the next entry of dynamicValues.
199  ValueRange dynamicValues,
200  MLIRContext *context) {
201  assert(dynamicValues.size() == static_cast<size_t>(llvm::count_if(
202  staticValues, ShapedType::isDynamic)) &&
203  "expected the rank of dynamic values to match the number of "
204  "values known to be dynamic");
206  res.reserve(staticValues.size());
207  unsigned numDynamic = 0;
208  unsigned count = static_cast<unsigned>(staticValues.size());
209  for (unsigned idx = 0; idx < count; ++idx) {
210  int64_t value = staticValues[idx];
211  res.push_back(ShapedType::isDynamic(value)
212  ? OpFoldResult{dynamicValues[numDynamic++]}
214  IntegerType::get(context, 64), staticValues[idx])});
215  }
216  return res;
217 }
219  ValueRange dynamicValues, Builder &b) {
220  return getMixedValues(staticValues, dynamicValues, b.getContext());
221 }
222 
223 /// Decompose a vector of mixed static or dynamic values into the corresponding
224 /// pair of arrays. This is the inverse function of `getMixedValues`.
225 std::pair<SmallVector<int64_t>, SmallVector<Value>>
227  SmallVector<int64_t> staticValues;
228  SmallVector<Value> dynamicValues;
229  for (const auto &it : mixedValues) {
230  if (auto attr = dyn_cast<Attribute>(it)) {
231  staticValues.push_back(cast<IntegerAttr>(attr).getInt());
232  } else {
233  staticValues.push_back(ShapedType::kDynamic);
234  dynamicValues.push_back(cast<Value>(it));
235  }
236  }
237  return {staticValues, dynamicValues};
238 }
239 
240 /// Helper to sort `values` according to matching `keys`.
241 template <typename K, typename V>
242 static SmallVector<V>
244  llvm::function_ref<bool(K, K)> compare) {
245  if (keys.empty())
246  return SmallVector<V>{values};
247  assert(keys.size() == values.size() && "unexpected mismatching sizes");
248  auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
249  llvm::sort(indices,
250  [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
251  SmallVector<V> res;
252  res.reserve(values.size());
253  for (int64_t i = 0, e = indices.size(); i < e; ++i)
254  res.push_back(values[indices[i]]);
255  return res;
256 }
257 
258 SmallVector<Value>
261  return getValuesSortedByKeyImpl(keys, values, compare);
262 }
263 
264 SmallVector<OpFoldResult>
267  return getValuesSortedByKeyImpl(keys, values, compare);
268 }
269 
270 SmallVector<int64_t>
273  return getValuesSortedByKeyImpl(keys, values, compare);
274 }
275 
276 /// Return the number of iterations for a loop with a lower bound `lb`, upper
277 /// bound `ub` and step `step`.
278 std::optional<APInt> constantTripCount(
279  OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned,
280  llvm::function_ref<std::optional<llvm::APSInt>(Value, Value, bool)>
282  // This is the bitwidth used to return 0 when loop does not execute.
283  // We infer it from the type of the bound if it isn't an index type.
284  auto getBitwidth = [&](OpFoldResult ofr) -> std::tuple<int, bool> {
285  if (auto intAttr =
286  dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(ofr))) {
287  if (auto intType = dyn_cast<IntegerType>(intAttr.getType()))
288  return std::make_tuple(intType.getWidth(), intType.isIndex());
289  } else {
290  auto val = cast<Value>(ofr);
291  if (auto intType = dyn_cast<IntegerType>(val.getType()))
292  return std::make_tuple(intType.getWidth(), intType.isIndex());
293  }
294  return std::make_tuple(IndexType::kInternalStorageBitWidth, true);
295  };
296  auto [bitwidth, isIndex] = getBitwidth(lb);
297  // This would better be an assert, but unfortunately it breaks scf.for_all
298  // which is missing attributes and SSA value optionally for its bounds, and
299  // uses Index type for the dynamic bounds but i64 for the static bounds. This
300  // is broken...
301  if (std::tie(bitwidth, isIndex) != getBitwidth(ub)) {
302  LDBG() << "mismatch between lb and ub bitwidth/type: " << ub << " vs "
303  << lb;
304  return std::nullopt;
305  }
306  if (lb == ub)
307  return APInt(bitwidth, 0);
308 
309  std::optional<std::pair<APInt, bool>> maybeStepCst =
310  getConstantAPIntValue(step);
311 
312  if (maybeStepCst) {
313  auto &stepCst = maybeStepCst->first;
314  assert(static_cast<int>(stepCst.getBitWidth()) == bitwidth &&
315  "step must have the same bitwidth as lb and ub");
316  if (stepCst.isZero())
317  return stepCst;
318  if (stepCst.isNegative())
319  return APInt(bitwidth, 0);
320  }
321 
322  if (isIndex) {
323  LDBG()
324  << "Computing loop trip count for index type may break with overflow";
325  // TODO: we can't compute the trip count for index type. We should fix this
326  // but too many tests are failing right now.
327  // return {};
328  }
329 
330  /// Compute the difference between the upper and lower bound: either from the
331  /// constant value or using the computeUbMinusLb callback.
332  llvm::APSInt diff;
333  std::optional<std::pair<APInt, bool>> maybeLbCst = getConstantAPIntValue(lb);
334  std::optional<std::pair<APInt, bool>> maybeUbCst = getConstantAPIntValue(ub);
335  if (maybeLbCst) {
336  // If one of the bounds is not a constant, we can't compute the trip count.
337  if (!maybeUbCst)
338  return std::nullopt;
339  APSInt lbCst(maybeLbCst->first, /*isUnsigned=*/!isSigned);
340  APSInt ubCst(maybeUbCst->first, /*isUnsigned=*/!isSigned);
341  if (!maybeUbCst)
342  return std::nullopt;
343  if (ubCst <= lbCst) {
344  LDBG() << "constantTripCount is 0 because ub <= lb (" << lbCst << "("
345  << lbCst.getBitWidth() << ") <= " << ubCst << "("
346  << ubCst.getBitWidth() << "), "
347  << (isSigned ? "isSigned" : "isUnsigned") << ")";
348  return APInt(bitwidth, 0);
349  }
350  diff = ubCst - lbCst;
351  } else {
352  if (maybeUbCst)
353  return std::nullopt;
354 
355  /// Non-constant bound, let's try to compute the difference between the
356  /// upper and lower bound
357  std::optional<llvm::APSInt> maybeDiff =
358  computeUbMinusLb(cast<Value>(lb), cast<Value>(ub), isSigned);
359  if (!maybeDiff)
360  return std::nullopt;
361  diff = *maybeDiff;
362  }
363  LDBG() << "constantTripCount: " << (isSigned ? "isSigned" : "isUnsigned")
364  << ", ub-lb: " << diff << "(" << diff.getBitWidth() << "b)";
365  if (diff.isNegative()) {
366  LDBG() << "constantTripCount is 0 because ub-lb diff is negative";
367  return APInt(bitwidth, 0);
368  }
369  if (!maybeStepCst) {
370  LDBG()
371  << "constantTripCount can't be computed because step is not a constant";
372  return std::nullopt;
373  }
374  auto &stepCst = maybeStepCst->first;
375  llvm::APInt tripCount = diff.sdiv(stepCst);
376  llvm::APInt r = diff.srem(stepCst);
377  if (!r.isZero())
378  tripCount = tripCount + 1;
379  LDBG() << "constantTripCount found: " << tripCount;
380  return tripCount;
381 }
382 
384  return llvm::none_of(sizesOrOffsets, [](int64_t value) {
385  return ShapedType::isStatic(value) && value < 0;
386  });
387 }
388 
390  return llvm::none_of(strides, [](int64_t value) {
391  return ShapedType::isStatic(value) && value == 0;
392  });
393 }
394 
396  bool onlyNonNegative, bool onlyNonZero) {
397  bool valuesChanged = false;
398  for (OpFoldResult &ofr : ofrs) {
399  if (isa<Attribute>(ofr))
400  continue;
401  Attribute attr;
402  if (matchPattern(cast<Value>(ofr), m_Constant(&attr))) {
403  // Note: All ofrs have index type.
404  if (onlyNonNegative && *getConstantIntValue(attr) < 0)
405  continue;
406  if (onlyNonZero && *getConstantIntValue(attr) == 0)
407  continue;
408  ofr = attr;
409  valuesChanged = true;
410  }
411  }
412  return success(valuesChanged);
413 }
414 
415 LogicalResult
417  return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true,
418  /*onlyNonZero=*/false);
419 }
420 
422  return foldDynamicIndexList(strides, /*onlyNonNegative=*/false,
423  /*onlyNonZero=*/true);
424 }
425 
426 } // namespace mlir
static std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
Definition: SCF.cpp:115
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:107
MLIRContext * getContext() const
Definition: Builders.h:56
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:68
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
bool areConstantIntValues(ArrayRef< OpFoldResult > ofrs, ArrayRef< int64_t > values)
Return true if all of ofrs are constant integers equal to the corresponding value in values.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
std::optional< std::pair< APInt, bool > > getConstantAPIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
static SmallVector< V > getValuesSortedByKeyImpl(ArrayRef< K > keys, ArrayRef< V > values, llvm::function_ref< bool(K, K)> compare)
Helper to sort values according to matching keys.
bool isEqualConstantIntOrValueArray(ArrayRef< OpFoldResult > ofrs1, ArrayRef< OpFoldResult > ofrs2)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void dispatchIndexOpFoldResult(OpFoldResult ofr, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch an OpFoldResult into staticVec if: a) it is an IntegerAttr In other cases...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)
If all ofrs are constant integers or IntegerAttrs, return the integers.
SmallVector< Value > getValuesSortedByKey(ArrayRef< Attribute > keys, ArrayRef< Value > values, llvm::function_ref< bool(Attribute, Attribute)> compare)
Helper to sort values according to matching keys.
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
bool isOneInteger(OpFoldResult v)
Return true if v is a constant integer (an IntegerAttr or a constant SSA value) with value 1.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.