MLIR  22.0.0git
StaticValueUtils.cpp
Go to the documentation of this file.
1 //===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Matchers.h"
11 #include "mlir/Support/LLVM.h"
12 #include "llvm/ADT/APSInt.h"
13 #include "llvm/ADT/STLExtras.h"
14 #include "llvm/Support/MathExtras.h"
15 
16 namespace mlir {
17 
19 
20 bool isOneInteger(OpFoldResult v) { return isConstantIntValue(v, 1); }
21 
22 std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
23  SmallVector<OpFoldResult>>
25  SmallVector<OpFoldResult> offsets, sizes, strides;
26  offsets.reserve(ranges.size());
27  sizes.reserve(ranges.size());
28  strides.reserve(ranges.size());
29  for (const auto &[offset, size, stride] : ranges) {
30  offsets.push_back(offset);
31  sizes.push_back(size);
32  strides.push_back(stride);
33  }
34  return std::make_tuple(offsets, sizes, strides);
35 }
36 
37 /// Helper function to dispatch an OpFoldResult into `staticVec` if:
38 /// a) it is an IntegerAttr
39 /// In other cases, the OpFoldResult is dispached to the `dynamicVec`.
40 /// In such dynamic cases, a copy of the `sentinel` value is also pushed to
41 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
42 /// come from an AttrSizedOperandSegments trait.
44  SmallVectorImpl<Value> &dynamicVec,
45  SmallVectorImpl<int64_t> &staticVec) {
46  auto v = llvm::dyn_cast_if_present<Value>(ofr);
47  if (!v) {
48  APInt apInt = cast<IntegerAttr>(cast<Attribute>(ofr)).getValue();
49  staticVec.push_back(apInt.getSExtValue());
50  return;
51  }
52  dynamicVec.push_back(v);
53  staticVec.push_back(ShapedType::kDynamic);
54 }
55 
56 std::pair<int64_t, OpFoldResult>
58  int64_t tileSizeForShape =
59  getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
60 
61  OpFoldResult tileSizeOfrSimplified =
62  (tileSizeForShape != ShapedType::kDynamic)
63  ? b.getIndexAttr(tileSizeForShape)
64  : tileSizeOfr;
65 
66  return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
67  tileSizeOfrSimplified);
68 }
69 
71  SmallVectorImpl<Value> &dynamicVec,
72  SmallVectorImpl<int64_t> &staticVec) {
73  for (OpFoldResult ofr : ofrs)
74  dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec);
75 }
76 
77 /// Given a value, try to extract a constant Attribute. If this fails, return
78 /// the original value.
80  if (!val)
81  return OpFoldResult();
82  Attribute attr;
83  if (matchPattern(val, m_Constant(&attr)))
84  return attr;
85  return val;
86 }
87 
88 /// Given an array of values, try to extract a constant Attribute from each
89 /// value. If this fails, return the original value.
/// Each element is converted with the single-Value getAsOpFoldResult, so the
/// result has one entry per input value, in the same order.
/// NOTE(review): the signature line (original line 90) is missing from this
/// listing; presumably it takes a range of Values named `values` - confirm.
91  return llvm::to_vector(
92  llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
93 }
94 
95 /// Convert `arrayAttr` to a vector of OpFoldResult.
/// Every element Attribute is pushed as-is (no filtering or conversion), so
/// the result has exactly `arrayAttr.size()` entries, in the same order.
/// NOTE(review): original lines 96-97 (the signature and the declaration of
/// the result vector `res`) are missing from this listing.
98  res.reserve(arrayAttr.size());
99  for (Attribute a : arrayAttr)
100  res.push_back(a);
101  return res;
102 }
103 
105  return IntegerAttr::get(IndexType::get(ctx), val);
106 }
107 
109  ArrayRef<int64_t> values) {
// Converts each int64_t in `values` to an index-typed IntegerAttr through the
// scalar getAsIndexOpFoldResult overload, preserving order.
// NOTE(review): the first signature line (original line 108) is missing from
// this listing; `ctx` is presumably an MLIRContext * parameter declared there.
110  return llvm::to_vector(llvm::map_range(
111  values, [ctx](int64_t v) { return getAsIndexOpFoldResult(ctx, v); }));
112 }
113 
114 /// If ofr is a constant integer or an IntegerAttr, return the integer.
115 std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
116  // Case 1: Check for Constant integer.
117  if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
118  APSInt intVal;
119  if (matchPattern(val, m_ConstantInt(&intVal)))
120  return intVal.getSExtValue();
121  return std::nullopt;
122  }
123  // Case 2: Check for IntegerAttr.
124  Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
125  if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
126  return intAttr.getValue().getSExtValue();
127  return std::nullopt;
128 }
129 
130 std::optional<SmallVector<int64_t>>
132  bool failed = false;
133  SmallVector<int64_t> res = llvm::map_to_vector(ofrs, [&](OpFoldResult ofr) {
134  auto cv = getConstantIntValue(ofr);
135  if (!cv.has_value())
136  failed = true;
137  return cv.value_or(0);
138  });
139  if (failed)
140  return std::nullopt;
141  return res;
142 }
143 
144 bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
145  return getConstantIntValue(ofr) == value;
146 }
147 
149  return llvm::all_of(
150  ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
151 }
152 
154  ArrayRef<int64_t> values) {
155  if (ofrs.size() != values.size())
156  return false;
157  std::optional<SmallVector<int64_t>> constOfrs = getConstantIntValues(ofrs);
158  return constOfrs && llvm::equal(constOfrs.value(), values);
159 }
160 
161 /// Return true if ofr1 and ofr2 are the same integer constant attribute values
162 /// or the same SSA value.
163 /// Ignore integer bitwidth and type mismatch that come from the fact there is
164 /// no IndexAttr and that IndexType has no bitwidth.
166  auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
167  if (cst1 && cst2 && *cst1 == *cst2)
168  return true;
169  auto v1 = llvm::dyn_cast_if_present<Value>(ofr1),
170  v2 = llvm::dyn_cast_if_present<Value>(ofr2);
171  return v1 && v1 == v2;
172 }
173 
175  ArrayRef<OpFoldResult> ofrs2) {
176  if (ofrs1.size() != ofrs2.size())
177  return false;
178  for (auto [ofr1, ofr2] : llvm::zip_equal(ofrs1, ofrs2))
179  if (!isEqualConstantIntOrValue(ofr1, ofr2))
180  return false;
181  return true;
182 }
183 
184 /// Return a vector of OpFoldResults with the same size as staticValues, but all
185 /// elements for which ShapedType::isDynamic is true, will be replaced by
186 /// dynamicValues.
188  ValueRange dynamicValues,
189  MLIRContext *context) {
190  assert(dynamicValues.size() == static_cast<size_t>(llvm::count_if(
191  staticValues, ShapedType::isDynamic)) &&
192  "expected the rank of dynamic values to match the number of "
193  "values known to be dynamic");
195  res.reserve(staticValues.size());
196  unsigned numDynamic = 0;
197  unsigned count = static_cast<unsigned>(staticValues.size());
198  for (unsigned idx = 0; idx < count; ++idx) {
199  int64_t value = staticValues[idx];
200  res.push_back(ShapedType::isDynamic(value)
201  ? OpFoldResult{dynamicValues[numDynamic++]}
203  IntegerType::get(context, 64), staticValues[idx])});
204  }
205  return res;
206 }
208  ValueRange dynamicValues, Builder &b) {
// Convenience overload: forwards to the MLIRContext-based getMixedValues,
// using the builder's context.
// NOTE(review): the first signature line (original line 207) is missing from
// this listing; `staticValues` is presumably declared there.
209  return getMixedValues(staticValues, dynamicValues, b.getContext());
210 }
211 
212 /// Decompose a vector of mixed static or dynamic values into the corresponding
213 /// pair of arrays. This is the inverse function of `getMixedValues`.
214 std::pair<SmallVector<int64_t>, SmallVector<Value>>
216  SmallVector<int64_t> staticValues;
217  SmallVector<Value> dynamicValues;
218  for (const auto &it : mixedValues) {
219  if (auto attr = dyn_cast<Attribute>(it)) {
220  staticValues.push_back(cast<IntegerAttr>(attr).getInt());
221  } else {
222  staticValues.push_back(ShapedType::kDynamic);
223  dynamicValues.push_back(cast<Value>(it));
224  }
225  }
226  return {staticValues, dynamicValues};
227 }
228 
229 /// Helper to sort `values` according to matching `keys`.
230 template <typename K, typename V>
231 static SmallVector<V>
233  llvm::function_ref<bool(K, K)> compare) {
234  if (keys.empty())
235  return SmallVector<V>{values};
236  assert(keys.size() == values.size() && "unexpected mismatching sizes");
237  auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
238  llvm::sort(indices,
239  [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
240  SmallVector<V> res;
241  res.reserve(values.size());
242  for (int64_t i = 0, e = indices.size(); i < e; ++i)
243  res.push_back(values[indices[i]]);
244  return res;
245 }
246 
247 SmallVector<Value>
250  return getValuesSortedByKeyImpl(keys, values, compare);
251 }
252 
/// Presumably the OpFoldResult overload of getValuesSortedByKey: forwards to
/// getValuesSortedByKeyImpl, which returns `values` reordered so that the
/// matching `keys` are sorted by `compare` (or unchanged if `keys` is empty).
/// NOTE(review): original lines 254-255 (function name and parameter list) are
/// missing from this listing - confirm the parameter types.
253 SmallVector<OpFoldResult>
256  return getValuesSortedByKeyImpl(keys, values, compare);
257 }
258 
/// Presumably the int64_t overload of getValuesSortedByKey: forwards to
/// getValuesSortedByKeyImpl, which returns `values` reordered so that the
/// matching `keys` are sorted by `compare` (or unchanged if `keys` is empty).
/// NOTE(review): original lines 260-261 (function name and parameter list) are
/// missing from this listing - confirm the parameter types.
259 SmallVector<int64_t>
262  return getValuesSortedByKeyImpl(keys, values, compare);
263 }
264 
265 /// Return the number of iterations for a loop with a lower bound `lb`, upper
266 /// bound `ub` and step `step`.
267 std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
268  OpFoldResult step) {
269  if (lb == ub)
270  return 0;
271 
272  std::optional<int64_t> lbConstant = getConstantIntValue(lb);
273  if (!lbConstant)
274  return std::nullopt;
275  std::optional<int64_t> ubConstant = getConstantIntValue(ub);
276  if (!ubConstant)
277  return std::nullopt;
278  std::optional<int64_t> stepConstant = getConstantIntValue(step);
279  if (!stepConstant || *stepConstant == 0)
280  return std::nullopt;
281 
282  return llvm::divideCeilSigned(*ubConstant - *lbConstant, *stepConstant);
283 }
284 
286  return llvm::none_of(sizesOrOffsets, [](int64_t value) {
287  return ShapedType::isStatic(value) && value < 0;
288  });
289 }
290 
292  return llvm::none_of(strides, [](int64_t value) {
293  return ShapedType::isStatic(value) && value == 0;
294  });
295 }
296 
298  bool onlyNonNegative, bool onlyNonZero) {
299  bool valuesChanged = false;
300  for (OpFoldResult &ofr : ofrs) {
301  if (isa<Attribute>(ofr))
302  continue;
303  Attribute attr;
304  if (matchPattern(cast<Value>(ofr), m_Constant(&attr))) {
305  // Note: All ofrs have index type.
306  if (onlyNonNegative && *getConstantIntValue(attr) < 0)
307  continue;
308  if (onlyNonZero && *getConstantIntValue(attr) == 0)
309  continue;
310  ofr = attr;
311  valuesChanged = true;
312  }
313  }
314  return success(valuesChanged);
315 }
316 
317 LogicalResult
319  return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true,
320  /*onlyNonZero=*/false);
321 }
322 
324  return foldDynamicIndexList(strides, /*onlyNonNegative=*/false,
325  /*onlyNonZero=*/true);
326 }
327 
328 } // namespace mlir
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
MLIRContext * getContext() const
Definition: Builders.h:55
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:68
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
bool areConstantIntValues(ArrayRef< OpFoldResult > ofrs, ArrayRef< int64_t > values)
Return true if all of ofrs are constant integers equal to the corresponding value in values.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
static SmallVector< V > getValuesSortedByKeyImpl(ArrayRef< K > keys, ArrayRef< V > values, llvm::function_ref< bool(K, K)> compare)
Helper to sort values according to matching keys.
bool isEqualConstantIntOrValueArray(ArrayRef< OpFoldResult > ofrs1, ArrayRef< OpFoldResult > ofrs2)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void dispatchIndexOpFoldResult(OpFoldResult ofr, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch an OpFoldResult into staticVec if: a) it is an IntegerAttr In other cases...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)
If all ofrs are constant integers or IntegerAttrs, return the integers.
SmallVector< Value > getValuesSortedByKey(ArrayRef< Attribute > keys, ArrayRef< Value > values, llvm::function_ref< bool(Attribute, Attribute)> compare)
Helper to sort values according to matching keys.
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.