MLIR  18.0.0git
StaticValueUtils.cpp
Go to the documentation of this file.
1 //===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Matchers.h"
12 #include "mlir/Support/LLVM.h"
14 #include "llvm/ADT/APSInt.h"
15 
16 namespace mlir {
17 
19  if (!v)
20  return false;
21  if (auto attr = llvm::dyn_cast_if_present<Attribute>(v)) {
22  IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
23  return intAttr && intAttr.getValue().isZero();
24  }
25  if (auto cst = v.get<Value>().getDefiningOp<arith::ConstantIndexOp>())
26  return cst.value() == 0;
27  return false;
28 }
29 
30 std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
31  SmallVector<OpFoldResult>>
33  SmallVector<OpFoldResult> offsets, sizes, strides;
34  offsets.reserve(ranges.size());
35  sizes.reserve(ranges.size());
36  strides.reserve(ranges.size());
37  for (const auto &[offset, size, stride] : ranges) {
38  offsets.push_back(offset);
39  sizes.push_back(size);
40  strides.push_back(stride);
41  }
42  return std::make_tuple(offsets, sizes, strides);
43 }
44 
45 /// Helper function to dispatch an OpFoldResult into `staticVec` if:
46 /// a) it is an IntegerAttr
47 /// In other cases, the OpFoldResult is dispatched to the `dynamicVec`.
48 /// In such dynamic cases, a copy of the `sentinel` value is also pushed to
49 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
50 /// come from an AttrSizedOperandSegments trait.
52  SmallVectorImpl<Value> &dynamicVec,
53  SmallVectorImpl<int64_t> &staticVec) {
54  auto v = llvm::dyn_cast_if_present<Value>(ofr);
55  if (!v) {
56  APInt apInt = cast<IntegerAttr>(ofr.get<Attribute>()).getValue();
57  staticVec.push_back(apInt.getSExtValue());
58  return;
59  }
60  dynamicVec.push_back(v);
61  staticVec.push_back(ShapedType::kDynamic);
62 }
63 
65  SmallVectorImpl<Value> &dynamicVec,
66  SmallVectorImpl<int64_t> &staticVec) {
67  for (OpFoldResult ofr : ofrs)
68  dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec);
69 }
70 
71 /// Given a value, try to extract a constant Attribute. If this fails, return
72 /// the original value.
74  if (!val)
75  return OpFoldResult();
76  Attribute attr;
77  if (matchPattern(val, m_Constant(&attr)))
78  return attr;
79  return val;
80 }
81 
82 /// Given an array of values, try to extract a constant Attribute from each
83 /// value. If this fails, return the original value.
85  return llvm::to_vector(
86  llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
87 }
88 
89 /// Convert `arrayAttr` to a vector of OpFoldResult.
92  res.reserve(arrayAttr.size());
93  for (Attribute a : arrayAttr)
94  res.push_back(a);
95  return res;
96 }
97 
99  return IntegerAttr::get(IndexType::get(ctx), val);
100 }
101 
103  ArrayRef<int64_t> values) {
104  return llvm::to_vector(llvm::map_range(
105  values, [ctx](int64_t v) { return getAsIndexOpFoldResult(ctx, v); }));
106 }
107 
108 /// If ofr is a constant integer or an IntegerAttr, return the integer.
109 std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
110  // Case 1: Check for Constant integer.
111  if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
112  APSInt intVal;
113  if (matchPattern(val, m_ConstantInt(&intVal)))
114  return intVal.getSExtValue();
115  return std::nullopt;
116  }
117  // Case 2: Check for IntegerAttr.
118  Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
119  if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
120  return intAttr.getValue().getSExtValue();
121  return std::nullopt;
122 }
123 
124 std::optional<SmallVector<int64_t>>
126  bool failed = false;
127  SmallVector<int64_t> res = llvm::map_to_vector(ofrs, [&](OpFoldResult ofr) {
128  auto cv = getConstantIntValue(ofr);
129  if (!cv.has_value())
130  failed = true;
131  return cv.has_value() ? cv.value() : 0;
132  });
133  if (failed)
134  return std::nullopt;
135  return res;
136 }
137 
138 /// Return true if `ofr` is constant integer equal to `value`.
139 bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
140  auto val = getConstantIntValue(ofr);
141  return val && *val == value;
142 }
143 
144 /// Return true if ofr1 and ofr2 are the same integer constant attribute values
145 /// or the same SSA value.
146 /// Ignore integer bitwidth and type mismatch that come from the fact there is
147 /// no IndexAttr and that IndexType has no bitwidth.
149  auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
150  if (cst1 && cst2 && *cst1 == *cst2)
151  return true;
152  auto v1 = llvm::dyn_cast_if_present<Value>(ofr1),
153  v2 = llvm::dyn_cast_if_present<Value>(ofr2);
154  return v1 && v1 == v2;
155 }
156 
158  ArrayRef<OpFoldResult> ofrs2) {
159  if (ofrs1.size() != ofrs2.size())
160  return false;
161  for (auto [ofr1, ofr2] : llvm::zip_equal(ofrs1, ofrs2))
162  if (!isEqualConstantIntOrValue(ofr1, ofr2))
163  return false;
164  return true;
165 }
166 
167 /// Return a vector of OpFoldResults with the same size as staticValues, but all
168 /// elements for which ShapedType::isDynamic is true, will be replaced by
169 /// dynamicValues.
171  ValueRange dynamicValues, Builder &b) {
173  res.reserve(staticValues.size());
174  unsigned numDynamic = 0;
175  unsigned count = static_cast<unsigned>(staticValues.size());
176  for (unsigned idx = 0; idx < count; ++idx) {
177  int64_t value = staticValues[idx];
178  res.push_back(ShapedType::isDynamic(value)
179  ? OpFoldResult{dynamicValues[numDynamic++]}
180  : OpFoldResult{b.getI64IntegerAttr(staticValues[idx])});
181  }
182  return res;
183 }
184 
185 /// Decompose a vector of mixed static or dynamic values into the corresponding
186 /// pair of arrays. This is the inverse function of `getMixedValues`.
187 std::pair<ArrayAttr, SmallVector<Value>>
189  const SmallVectorImpl<OpFoldResult> &mixedValues) {
190  SmallVector<int64_t> staticValues;
191  SmallVector<Value> dynamicValues;
192  for (const auto &it : mixedValues) {
193  if (it.is<Attribute>()) {
194  staticValues.push_back(cast<IntegerAttr>(it.get<Attribute>()).getInt());
195  } else {
196  staticValues.push_back(ShapedType::kDynamic);
197  dynamicValues.push_back(it.get<Value>());
198  }
199  }
200  return {b.getI64ArrayAttr(staticValues), dynamicValues};
201 }
202 
203 /// Helper to sort `values` according to matching `keys`.
204 template <typename K, typename V>
205 static SmallVector<V>
207  llvm::function_ref<bool(K, K)> compare) {
208  if (keys.empty())
209  return SmallVector<V>{values};
210  assert(keys.size() == values.size() && "unexpected mismatching sizes");
211  auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
212  std::sort(indices.begin(), indices.end(),
213  [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
214  SmallVector<V> res;
215  res.reserve(values.size());
216  for (int64_t i = 0, e = indices.size(); i < e; ++i)
217  res.push_back(values[indices[i]]);
218  return res;
219 }
220 
221 SmallVector<Value>
224  return getValuesSortedByKeyImpl(keys, values, compare);
225 }
226 
227 SmallVector<OpFoldResult>
230  return getValuesSortedByKeyImpl(keys, values, compare);
231 }
232 
233 SmallVector<int64_t>
236  return getValuesSortedByKeyImpl(keys, values, compare);
237 }
238 
239 /// Return the number of iterations for a loop with a lower bound `lb`, upper
240 /// bound `ub` and step `step`.
241 std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
242  OpFoldResult step) {
243  if (lb == ub)
244  return 0;
245 
246  std::optional<int64_t> lbConstant = getConstantIntValue(lb);
247  if (!lbConstant)
248  return std::nullopt;
249  std::optional<int64_t> ubConstant = getConstantIntValue(ub);
250  if (!ubConstant)
251  return std::nullopt;
252  std::optional<int64_t> stepConstant = getConstantIntValue(step);
253  if (!stepConstant)
254  return std::nullopt;
255 
256  return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
257 }
258 
260  bool valuesChanged = false;
261  for (OpFoldResult &ofr : ofrs) {
262  if (ofr.is<Attribute>())
263  continue;
264  Attribute attr;
265  if (matchPattern(ofr.get<Value>(), m_Constant(&attr))) {
266  ofr = attr;
267  valuesChanged = true;
268  }
269  }
270  return success(valuesChanged);
271 }
272 
273 } // namespace mlir
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:128
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:288
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a single result from folding an operation.
Definition: OpDefinition.h:266
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:90
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:63
This header declares functions that assist transformations in the MemRef dialect.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
bool isZeroIndex(OpFoldResult v)
Return true if v is an IntegerAttr with value 0 or a ConstantIndexOp with attribute value 0.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
Definition: MathExtras.h:23
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::pair< ArrayAttr, SmallVector< Value > > decomposeMixedValues(Builder &b, const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs)
Returns "success" when any of the elements in ofrs is a constant value.
static SmallVector< V > getValuesSortedByKeyImpl(ArrayRef< K > keys, ArrayRef< V > values, llvm::function_ref< bool(K, K)> compare)
Helper to sort values according to matching keys.
bool isEqualConstantIntOrValueArray(ArrayRef< OpFoldResult > ofrs1, ArrayRef< OpFoldResult > ofrs2)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
void dispatchIndexOpFoldResult(OpFoldResult ofr, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch an OpFoldResult into staticVec if: a) it is an IntegerAttr In other cases...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)
If all ofrs are constant integers or IntegerAttrs, return the integers.
SmallVector< Value > getValuesSortedByKey(ArrayRef< Attribute > keys, ArrayRef< Value > values, llvm::function_ref< bool(Attribute, Attribute)> compare)
Helper to sort values according to matching keys.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.