MLIR  19.0.0git
StaticValueUtils.cpp
Go to the documentation of this file.
1 //===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Matchers.h"
12 #include "mlir/Support/LLVM.h"
14 #include "llvm/ADT/APSInt.h"
15 
16 namespace mlir {
17 
19  if (!v)
20  return false;
21  if (auto attr = llvm::dyn_cast_if_present<Attribute>(v)) {
22  IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
23  return intAttr && intAttr.getValue().isZero();
24  }
25  if (auto cst = v.get<Value>().getDefiningOp<arith::ConstantIndexOp>())
26  return cst.value() == 0;
27  return false;
28 }
29 
30 std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
31  SmallVector<OpFoldResult>>
33  SmallVector<OpFoldResult> offsets, sizes, strides;
34  offsets.reserve(ranges.size());
35  sizes.reserve(ranges.size());
36  strides.reserve(ranges.size());
37  for (const auto &[offset, size, stride] : ranges) {
38  offsets.push_back(offset);
39  sizes.push_back(size);
40  strides.push_back(stride);
41  }
42  return std::make_tuple(offsets, sizes, strides);
43 }
44 
45 /// Helper function to dispatch an OpFoldResult into `staticVec` if:
46 /// a) it is an IntegerAttr
47 /// In other cases, the OpFoldResult is dispatched to the `dynamicVec`.
48 /// In such dynamic cases, a copy of the `sentinel` value is also pushed to
49 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
50 /// come from an AttrSizedOperandSegments trait.
52  SmallVectorImpl<Value> &dynamicVec,
53  SmallVectorImpl<int64_t> &staticVec) {
54  auto v = llvm::dyn_cast_if_present<Value>(ofr);
55  if (!v) {
56  APInt apInt = cast<IntegerAttr>(ofr.get<Attribute>()).getValue();
57  staticVec.push_back(apInt.getSExtValue());
58  return;
59  }
60  dynamicVec.push_back(v);
61  staticVec.push_back(ShapedType::kDynamic);
62 }
63 
65  SmallVectorImpl<Value> &dynamicVec,
66  SmallVectorImpl<int64_t> &staticVec) {
67  for (OpFoldResult ofr : ofrs)
68  dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec);
69 }
70 
71 /// Given a value, try to extract a constant Attribute. If this fails, return
72 /// the original value.
74  if (!val)
75  return OpFoldResult();
76  Attribute attr;
77  if (matchPattern(val, m_Constant(&attr)))
78  return attr;
79  return val;
80 }
81 
82 /// Given an array of values, try to extract a constant Attribute from each
83 /// value. If this fails, return the original value.
85  return llvm::to_vector(
86  llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
87 }
88 
89 /// Convert `arrayAttr` to a vector of OpFoldResult.
92  res.reserve(arrayAttr.size());
93  for (Attribute a : arrayAttr)
94  res.push_back(a);
95  return res;
96 }
97 
99  return IntegerAttr::get(IndexType::get(ctx), val);
100 }
101 
103  ArrayRef<int64_t> values) {
104  return llvm::to_vector(llvm::map_range(
105  values, [ctx](int64_t v) { return getAsIndexOpFoldResult(ctx, v); }));
106 }
107 
108 /// If ofr is a constant integer or an IntegerAttr, return the integer.
109 std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
110  // Case 1: Check for Constant integer.
111  if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
112  APSInt intVal;
113  if (matchPattern(val, m_ConstantInt(&intVal)))
114  return intVal.getSExtValue();
115  return std::nullopt;
116  }
117  // Case 2: Check for IntegerAttr.
118  Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
119  if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
120  return intAttr.getValue().getSExtValue();
121  return std::nullopt;
122 }
123 
124 std::optional<SmallVector<int64_t>>
126  bool failed = false;
127  SmallVector<int64_t> res = llvm::map_to_vector(ofrs, [&](OpFoldResult ofr) {
128  auto cv = getConstantIntValue(ofr);
129  if (!cv.has_value())
130  failed = true;
131  return cv.has_value() ? cv.value() : 0;
132  });
133  if (failed)
134  return std::nullopt;
135  return res;
136 }
137 
138 /// Return true if `ofr` is constant integer equal to `value`.
139 bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
140  auto val = getConstantIntValue(ofr);
141  return val && *val == value;
142 }
143 
144 /// Return true if ofr1 and ofr2 are the same integer constant attribute values
145 /// or the same SSA value.
146 /// Ignore integer bitwidth and type mismatch that come from the fact there is
147 /// no IndexAttr and that IndexType has no bitwidth.
149  auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
150  if (cst1 && cst2 && *cst1 == *cst2)
151  return true;
152  auto v1 = llvm::dyn_cast_if_present<Value>(ofr1),
153  v2 = llvm::dyn_cast_if_present<Value>(ofr2);
154  return v1 && v1 == v2;
155 }
156 
158  ArrayRef<OpFoldResult> ofrs2) {
159  if (ofrs1.size() != ofrs2.size())
160  return false;
161  for (auto [ofr1, ofr2] : llvm::zip_equal(ofrs1, ofrs2))
162  if (!isEqualConstantIntOrValue(ofr1, ofr2))
163  return false;
164  return true;
165 }
166 
167 /// Return a vector of OpFoldResults with the same size as staticValues, but all
168 /// elements for which ShapedType::isDynamic is true will be replaced by
169 /// dynamicValues.
171  ValueRange dynamicValues, Builder &b) {
173  res.reserve(staticValues.size());
174  unsigned numDynamic = 0;
175  unsigned count = static_cast<unsigned>(staticValues.size());
176  for (unsigned idx = 0; idx < count; ++idx) {
177  int64_t value = staticValues[idx];
178  res.push_back(ShapedType::isDynamic(value)
179  ? OpFoldResult{dynamicValues[numDynamic++]}
180  : OpFoldResult{b.getI64IntegerAttr(staticValues[idx])});
181  }
182  return res;
183 }
184 
185 /// Decompose a vector of mixed static or dynamic values into the corresponding
186 /// pair of arrays. This is the inverse function of `getMixedValues`.
187 std::pair<ArrayAttr, SmallVector<Value>>
189  const SmallVectorImpl<OpFoldResult> &mixedValues) {
190  SmallVector<int64_t> staticValues;
191  SmallVector<Value> dynamicValues;
192  for (const auto &it : mixedValues) {
193  if (it.is<Attribute>()) {
194  staticValues.push_back(cast<IntegerAttr>(it.get<Attribute>()).getInt());
195  } else {
196  staticValues.push_back(ShapedType::kDynamic);
197  dynamicValues.push_back(it.get<Value>());
198  }
199  }
200  return {b.getI64ArrayAttr(staticValues), dynamicValues};
201 }
202 
203 /// Helper to sort `values` according to matching `keys`.
204 template <typename K, typename V>
205 static SmallVector<V>
207  llvm::function_ref<bool(K, K)> compare) {
208  if (keys.empty())
209  return SmallVector<V>{values};
210  assert(keys.size() == values.size() && "unexpected mismatching sizes");
211  auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
212  std::sort(indices.begin(), indices.end(),
213  [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
214  SmallVector<V> res;
215  res.reserve(values.size());
216  for (int64_t i = 0, e = indices.size(); i < e; ++i)
217  res.push_back(values[indices[i]]);
218  return res;
219 }
220 
221 SmallVector<Value>
224  return getValuesSortedByKeyImpl(keys, values, compare);
225 }
226 
227 SmallVector<OpFoldResult>
230  return getValuesSortedByKeyImpl(keys, values, compare);
231 }
232 
233 SmallVector<int64_t>
236  return getValuesSortedByKeyImpl(keys, values, compare);
237 }
238 
239 /// Return the number of iterations for a loop with a lower bound `lb`, upper
240 /// bound `ub` and step `step`.
241 std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
242  OpFoldResult step) {
243  if (lb == ub)
244  return 0;
245 
246  std::optional<int64_t> lbConstant = getConstantIntValue(lb);
247  if (!lbConstant)
248  return std::nullopt;
249  std::optional<int64_t> ubConstant = getConstantIntValue(ub);
250  if (!ubConstant)
251  return std::nullopt;
252  std::optional<int64_t> stepConstant = getConstantIntValue(step);
253  if (!stepConstant)
254  return std::nullopt;
255 
256  return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
257 }
258 
260  return llvm::none_of(sizesOrOffsets, [](int64_t value) {
261  return !ShapedType::isDynamic(value) && value < 0;
262  });
263 }
264 
266  return llvm::none_of(strides, [](int64_t value) {
267  return !ShapedType::isDynamic(value) && value == 0;
268  });
269 }
270 
272  bool onlyNonNegative, bool onlyNonZero) {
273  bool valuesChanged = false;
274  for (OpFoldResult &ofr : ofrs) {
275  if (ofr.is<Attribute>())
276  continue;
277  Attribute attr;
278  if (matchPattern(ofr.get<Value>(), m_Constant(&attr))) {
279  // Note: All ofrs have index type.
280  if (onlyNonNegative && *getConstantIntValue(attr) < 0)
281  continue;
282  if (onlyNonZero && *getConstantIntValue(attr) == 0)
283  continue;
284  ofr = attr;
285  valuesChanged = true;
286  }
287  }
288  return success(valuesChanged);
289 }
290 
291 LogicalResult
293  return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true,
294  /*onlyNonZero=*/false);
295 }
296 
298  return foldDynamicIndexList(strides, /*onlyNonNegative=*/false,
299  /*onlyNonZero=*/true);
300 }
301 
302 } // namespace mlir
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:128
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:288
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a single result from folding an operation.
Definition: OpDefinition.h:266
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:90
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:65
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
bool isZeroIndex(OpFoldResult v)
Return true if v is an IntegerAttr with value 0 or a ConstantIndexOp whose attribute has value 0.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
Definition: MathExtras.h:23
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::pair< ArrayAttr, SmallVector< Value > > decomposeMixedValues(Builder &b, const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
static SmallVector< V > getValuesSortedByKeyImpl(ArrayRef< K > keys, ArrayRef< V > values, llvm::function_ref< bool(K, K)> compare)
Helper to sort values according to matching keys.
bool isEqualConstantIntOrValueArray(ArrayRef< OpFoldResult > ofrs1, ArrayRef< OpFoldResult > ofrs2)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
void dispatchIndexOpFoldResult(OpFoldResult ofr, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch an OpFoldResult into staticVec if: a) it is an IntegerAttr In other cases...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)
If all ofrs are constant integers or IntegerAttrs, return the integers.
SmallVector< Value > getValuesSortedByKey(ArrayRef< Attribute > keys, ArrayRef< Value > values, llvm::function_ref< bool(Attribute, Attribute)> compare)
Helper to sort values according to matching keys.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.