MLIR 22.0.0git
StaticValueUtils.cpp
Go to the documentation of this file.
1//===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/IR/Attributes.h"
11#include "mlir/IR/Matchers.h"
12#include "mlir/Support/LLVM.h"
13#include "llvm/ADT/APSInt.h"
14#include "llvm/ADT/STLExtras.h"
15#include "llvm/Support/DebugLog.h"
16#include "llvm/Support/MathExtras.h"
17
18namespace mlir {
19
21
23
24std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
27 SmallVector<OpFoldResult> offsets, sizes, strides;
28 offsets.reserve(ranges.size());
29 sizes.reserve(ranges.size());
30 strides.reserve(ranges.size());
31 for (const auto &[offset, size, stride] : ranges) {
32 offsets.push_back(offset);
33 sizes.push_back(size);
34 strides.push_back(stride);
35 }
36 return std::make_tuple(offsets, sizes, strides);
37}
38
39/// Helper function to dispatch an OpFoldResult into `staticVec` if:
40/// a) it is an IntegerAttr
41/// In other cases, the OpFoldResult is dispached to the `dynamicVec`.
42/// In such dynamic cases, a copy of the `sentinel` value is also pushed to
43/// `staticVec`. This is useful to extract mixed static and dynamic entries that
44/// come from an AttrSizedOperandSegments trait.
46 SmallVectorImpl<Value> &dynamicVec,
47 SmallVectorImpl<int64_t> &staticVec) {
48 auto v = llvm::dyn_cast_if_present<Value>(ofr);
49 if (!v) {
50 APInt apInt = cast<IntegerAttr>(cast<Attribute>(ofr)).getValue();
51 staticVec.push_back(apInt.getSExtValue());
52 return;
53 }
54 dynamicVec.push_back(v);
55 staticVec.push_back(ShapedType::kDynamic);
56}
57
58std::pair<int64_t, OpFoldResult>
60 int64_t tileSizeForShape =
61 getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
62
63 OpFoldResult tileSizeOfrSimplified =
64 (tileSizeForShape != ShapedType::kDynamic)
65 ? b.getIndexAttr(tileSizeForShape)
66 : tileSizeOfr;
67
68 return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
69 tileSizeOfrSimplified);
70}
71
73 SmallVectorImpl<Value> &dynamicVec,
74 SmallVectorImpl<int64_t> &staticVec) {
75 for (OpFoldResult ofr : ofrs)
76 dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec);
77}
78
79/// Given a value, try to extract a constant Attribute. If this fails, return
80/// the original value.
82 if (!val)
83 return OpFoldResult();
84 Attribute attr;
85 if (matchPattern(val, m_Constant(&attr)))
86 return attr;
87 return val;
88}
89
90/// Given an array of values, try to extract a constant Attribute from each
91/// value. If this fails, return the original value.
93 return llvm::to_vector(
94 llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
95}
96
97/// Convert `arrayAttr` to a vector of OpFoldResult.
100 res.reserve(arrayAttr.size());
101 for (Attribute a : arrayAttr)
102 res.push_back(a);
103 return res;
104}
105
107 return IntegerAttr::get(IndexType::get(ctx), val);
108}
109
111 ArrayRef<int64_t> values) {
112 return llvm::to_vector(llvm::map_range(
113 values, [ctx](int64_t v) { return getAsIndexOpFoldResult(ctx, v); }));
114}
115
116/// If ofr is a constant integer or an IntegerAttr, return the integer.
117/// The boolean indicates whether the value is an index type.
118std::optional<std::pair<APInt, bool>> getConstantAPIntValue(OpFoldResult ofr) {
119 // Case 1: Check for Constant integer.
120 if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
121 APInt intVal;
122 if (matchPattern(val, m_ConstantInt(&intVal)))
123 return std::make_pair(intVal, val.getType().isIndex());
124 return std::nullopt;
125 }
126 // Case 2: Check for IntegerAttr.
127 Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
128 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
129 return std::make_pair(intAttr.getValue(), intAttr.getType().isIndex());
130 return std::nullopt;
131}
132
133/// If ofr is a constant integer or an IntegerAttr, return the integer.
134std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
135 std::optional<std::pair<APInt, bool>> apInt = getConstantAPIntValue(ofr);
136 if (!apInt)
137 return std::nullopt;
138 return apInt->first.getSExtValue();
139}
140
141std::optional<SmallVector<int64_t>>
143 bool failed = false;
144 SmallVector<int64_t> res = llvm::map_to_vector(ofrs, [&](OpFoldResult ofr) {
145 auto cv = getConstantIntValue(ofr);
146 if (!cv.has_value())
147 failed = true;
148 return cv.value_or(0);
149 });
150 if (failed)
151 return std::nullopt;
152 return res;
153}
154
156 return getConstantIntValue(ofr) == value;
157}
158
160 return llvm::all_of(
161 ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
162}
163
165 ArrayRef<int64_t> values) {
166 if (ofrs.size() != values.size())
167 return false;
168 std::optional<SmallVector<int64_t>> constOfrs = getConstantIntValues(ofrs);
169 return constOfrs && llvm::equal(constOfrs.value(), values);
170}
171
172/// Return true if ofr1 and ofr2 are the same integer constant attribute values
173/// or the same SSA value.
174/// Ignore integer bitwidth and type mismatch that come from the fact there is
175/// no IndexAttr and that IndexType has no bitwidth.
177 auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
178 if (cst1 && cst2 && *cst1 == *cst2)
179 return true;
180 auto v1 = llvm::dyn_cast_if_present<Value>(ofr1),
181 v2 = llvm::dyn_cast_if_present<Value>(ofr2);
182 return v1 && v1 == v2;
183}
184
187 if (ofrs1.size() != ofrs2.size())
188 return false;
189 for (auto [ofr1, ofr2] : llvm::zip_equal(ofrs1, ofrs2))
190 if (!isEqualConstantIntOrValue(ofr1, ofr2))
191 return false;
192 return true;
193}
194
195/// Return a vector of OpFoldResults with the same size as staticValues, but all
196/// elements for which ShapedType::isDynamic is true, will be replaced by
197/// dynamicValues.
199 ValueRange dynamicValues,
200 MLIRContext *context) {
201 assert(dynamicValues.size() == static_cast<size_t>(llvm::count_if(
202 staticValues, ShapedType::isDynamic)) &&
203 "expected the rank of dynamic values to match the number of "
204 "values known to be dynamic");
206 res.reserve(staticValues.size());
207 unsigned numDynamic = 0;
208 unsigned count = static_cast<unsigned>(staticValues.size());
209 for (unsigned idx = 0; idx < count; ++idx) {
210 int64_t value = staticValues[idx];
211 res.push_back(ShapedType::isDynamic(value)
212 ? OpFoldResult{dynamicValues[numDynamic++]}
213 : OpFoldResult{IntegerAttr::get(
214 IntegerType::get(context, 64), staticValues[idx])});
215 }
216 return res;
217}
219 ValueRange dynamicValues, Builder &b) {
220 return getMixedValues(staticValues, dynamicValues, b.getContext());
221}
222
223/// Decompose a vector of mixed static or dynamic values into the corresponding
224/// pair of arrays. This is the inverse function of `getMixedValues`.
225std::pair<SmallVector<int64_t>, SmallVector<Value>>
227 SmallVector<int64_t> staticValues;
228 SmallVector<Value> dynamicValues;
229 for (const auto &it : mixedValues) {
230 if (auto attr = dyn_cast<Attribute>(it)) {
231 staticValues.push_back(cast<IntegerAttr>(attr).getInt());
232 } else {
233 staticValues.push_back(ShapedType::kDynamic);
234 dynamicValues.push_back(cast<Value>(it));
235 }
236 }
237 return {staticValues, dynamicValues};
238}
239
240/// Helper to sort `values` according to matching `keys`.
241template <typename K, typename V>
242static SmallVector<V>
244 llvm::function_ref<bool(K, K)> compare) {
245 if (keys.empty())
246 return SmallVector<V>{values};
247 assert(keys.size() == values.size() && "unexpected mismatching sizes");
248 auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
249 llvm::sort(indices,
250 [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
251 SmallVector<V> res;
252 res.reserve(values.size());
253 for (int64_t i = 0, e = indices.size(); i < e; ++i)
254 res.push_back(values[indices[i]]);
255 return res;
256}
257
260 llvm::function_ref<bool(Attribute, Attribute)> compare) {
261 return getValuesSortedByKeyImpl(keys, values, compare);
262}
263
266 llvm::function_ref<bool(Attribute, Attribute)> compare) {
267 return getValuesSortedByKeyImpl(keys, values, compare);
268}
269
272 llvm::function_ref<bool(Attribute, Attribute)> compare) {
273 return getValuesSortedByKeyImpl(keys, values, compare);
274}
275
276/// Return the number of iterations for a loop with a lower bound `lb`, upper
277/// bound `ub` and step `step`.
278std::optional<APInt> constantTripCount(
279 OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned,
280 llvm::function_ref<std::optional<llvm::APSInt>(Value, Value, bool)>
281 computeUbMinusLb) {
282 // This is the bitwidth used to return 0 when loop does not execute.
283 // We infer it from the type of the bound if it isn't an index type.
284 auto getBitwidth = [&](OpFoldResult ofr) -> std::tuple<int, bool> {
285 if (auto intAttr =
286 dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(ofr))) {
287 if (auto intType = dyn_cast<IntegerType>(intAttr.getType()))
288 return std::make_tuple(intType.getWidth(), intType.isIndex());
289 } else {
290 auto val = cast<Value>(ofr);
291 if (auto intType = dyn_cast<IntegerType>(val.getType()))
292 return std::make_tuple(intType.getWidth(), intType.isIndex());
293 }
294 return std::make_tuple(IndexType::kInternalStorageBitWidth, true);
295 };
296 auto [bitwidth, isIndex] = getBitwidth(lb);
297 // This would better be an assert, but unfortunately it breaks scf.for_all
298 // which is missing attributes and SSA value optionally for its bounds, and
299 // uses Index type for the dynamic bounds but i64 for the static bounds. This
300 // is broken...
301 if (std::tie(bitwidth, isIndex) != getBitwidth(ub)) {
302 LDBG() << "mismatch between lb and ub bitwidth/type: " << ub << " vs "
303 << lb;
304 return std::nullopt;
305 }
306 if (lb == ub)
307 return APInt(bitwidth, 0);
308
309 std::optional<std::pair<APInt, bool>> maybeStepCst =
311
312 if (maybeStepCst) {
313 auto &stepCst = maybeStepCst->first;
314 assert(static_cast<int>(stepCst.getBitWidth()) == bitwidth &&
315 "step must have the same bitwidth as lb and ub");
316 if (stepCst.isZero())
317 return stepCst;
318 if (stepCst.isNegative())
319 return APInt(bitwidth, 0);
320 }
321
322 if (isIndex) {
323 LDBG()
324 << "Computing loop trip count for index type may break with overflow";
325 // TODO: we can't compute the trip count for index type. We should fix this
326 // but too many tests are failing right now.
327 // return {};
328 }
329
330 /// Compute the difference between the upper and lower bound: either from the
331 /// constant value or using the computeUbMinusLb callback.
332 llvm::APSInt diff;
333 std::optional<std::pair<APInt, bool>> maybeLbCst = getConstantAPIntValue(lb);
334 std::optional<std::pair<APInt, bool>> maybeUbCst = getConstantAPIntValue(ub);
335 if (maybeLbCst) {
336 // If one of the bounds is not a constant, we can't compute the trip count.
337 if (!maybeUbCst)
338 return std::nullopt;
339 APSInt lbCst(maybeLbCst->first, /*isUnsigned=*/!isSigned);
340 APSInt ubCst(maybeUbCst->first, /*isUnsigned=*/!isSigned);
341 if (!maybeUbCst)
342 return std::nullopt;
343 if (ubCst <= lbCst) {
344 LDBG() << "constantTripCount is 0 because ub <= lb (" << lbCst << "("
345 << lbCst.getBitWidth() << ") <= " << ubCst << "("
346 << ubCst.getBitWidth() << "), "
347 << (isSigned ? "isSigned" : "isUnsigned") << ")";
348 return APInt(bitwidth, 0);
349 }
350 diff = ubCst - lbCst;
351 } else {
352 if (maybeUbCst)
353 return std::nullopt;
354
355 /// Non-constant bound, let's try to compute the difference between the
356 /// upper and lower bound
357 std::optional<llvm::APSInt> maybeDiff =
358 computeUbMinusLb(cast<Value>(lb), cast<Value>(ub), isSigned);
359 if (!maybeDiff)
360 return std::nullopt;
361 diff = *maybeDiff;
362 }
363 LDBG() << "constantTripCount: " << (isSigned ? "isSigned" : "isUnsigned")
364 << ", ub-lb: " << diff << "(" << diff.getBitWidth() << "b)";
365 if (diff.isNegative()) {
366 LDBG() << "constantTripCount is 0 because ub-lb diff is negative";
367 return APInt(bitwidth, 0);
368 }
369 if (!maybeStepCst) {
370 LDBG()
371 << "constantTripCount can't be computed because step is not a constant";
372 return std::nullopt;
373 }
374 auto &stepCst = maybeStepCst->first;
375 llvm::APInt tripCount = diff.sdiv(stepCst);
376 llvm::APInt r = diff.srem(stepCst);
377 if (!r.isZero())
378 tripCount = tripCount + 1;
379 LDBG() << "constantTripCount found: " << tripCount;
380 return tripCount;
381}
382
384 return llvm::none_of(sizesOrOffsets, [](int64_t value) {
385 return ShapedType::isStatic(value) && value < 0;
386 });
387}
388
390 return llvm::none_of(strides, [](int64_t value) {
391 return ShapedType::isStatic(value) && value == 0;
392 });
393}
394
396 bool onlyNonNegative, bool onlyNonZero) {
397 bool valuesChanged = false;
398 for (OpFoldResult &ofr : ofrs) {
399 if (isa<Attribute>(ofr))
400 continue;
401 Attribute attr;
402 if (matchPattern(cast<Value>(ofr), m_Constant(&attr))) {
403 // Note: All ofrs have index type.
404 if (onlyNonNegative && *getConstantIntValue(attr) < 0)
405 continue;
406 if (onlyNonZero && *getConstantIntValue(attr) == 0)
407 continue;
408 ofr = attr;
409 valuesChanged = true;
410 }
411 }
412 return success(valuesChanged);
413}
414
415LogicalResult
417 return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true,
418 /*onlyNonZero=*/false);
419}
420
422 return foldDynamicIndexList(strides, /*onlyNonNegative=*/false,
423 /*onlyNonZero=*/true);
424}
425
426} // namespace mlir
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
bool areConstantIntValues(ArrayRef< OpFoldResult > ofrs, ArrayRef< int64_t > values)
Return true if all of ofrs are constant integers equal to the corresponding value in values.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
static SmallVector< V > getValuesSortedByKeyImpl(ArrayRef< K > keys, ArrayRef< V > values, llvm::function_ref< bool(K, K)> compare)
Helper to sort values according to matching keys.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
std::optional< std::pair< APInt, bool > > getConstantAPIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
bool isEqualConstantIntOrValueArray(ArrayRef< OpFoldResult > ofrs1, ArrayRef< OpFoldResult > ofrs2)
void dispatchIndexOpFoldResult(OpFoldResult ofr, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch an OpFoldResult into staticVec if: a) it is an IntegerAttr. In other cases...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)
If all ofrs are constant integers or IntegerAttrs, return the integers.
SmallVector< Value > getValuesSortedByKey(ArrayRef< Attribute > keys, ArrayRef< Value > values, llvm::function_ref< bool(Attribute, Attribute)> compare)
Helper to sort values according to matching keys.
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.