MLIR 23.0.0git
StaticValueUtils.cpp
Go to the documentation of this file.
1//===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/IR/Attributes.h"
11#include "mlir/IR/Matchers.h"
12#include "mlir/Support/LLVM.h"
13#include "llvm/ADT/APSInt.h"
14#include "llvm/ADT/STLExtras.h"
15#include "llvm/ADT/SmallVectorExtras.h"
16#include "llvm/Support/DebugLog.h"
17#include "llvm/Support/MathExtras.h"
18
19namespace mlir {
20
22
24 if (auto attr = dyn_cast<Attribute>(v)) {
25 if (auto floatAttr = dyn_cast<FloatAttr>(attr))
26 return floatAttr.getValue().isZero();
27 return false;
28 }
29 return matchPattern(cast<Value>(v), m_AnyZeroFloat());
30}
31
35
37
38std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
39 SmallVector<OpFoldResult>>
41 SmallVector<OpFoldResult> offsets, sizes, strides;
42 offsets.reserve(ranges.size());
43 sizes.reserve(ranges.size());
44 strides.reserve(ranges.size());
45 for (const auto &[offset, size, stride] : ranges) {
46 offsets.push_back(offset);
47 sizes.push_back(size);
48 strides.push_back(stride);
49 }
50 return std::make_tuple(offsets, sizes, strides);
51}
52
53/// Helper function to dispatch an OpFoldResult into `staticVec` if:
54/// a) it is an IntegerAttr
55/// In other cases, the OpFoldResult is dispached to the `dynamicVec`.
56/// In such dynamic cases, a copy of the `sentinel` value is also pushed to
57/// `staticVec`. This is useful to extract mixed static and dynamic entries that
58/// come from an AttrSizedOperandSegments trait.
60 SmallVectorImpl<Value> &dynamicVec,
61 SmallVectorImpl<int64_t> &staticVec) {
62 auto v = llvm::dyn_cast_if_present<Value>(ofr);
63 if (!v) {
64 APInt apInt = cast<IntegerAttr>(cast<Attribute>(ofr)).getValue();
65 staticVec.push_back(apInt.getSExtValue());
66 return;
67 }
68 dynamicVec.push_back(v);
69 staticVec.push_back(ShapedType::kDynamic);
70}
71
72std::pair<int64_t, OpFoldResult>
74 int64_t tileSizeForShape =
75 getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
76
77 OpFoldResult tileSizeOfrSimplified =
78 (tileSizeForShape != ShapedType::kDynamic)
79 ? b.getIndexAttr(tileSizeForShape)
80 : tileSizeOfr;
81
82 return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
83 tileSizeOfrSimplified);
84}
85
87 SmallVectorImpl<Value> &dynamicVec,
88 SmallVectorImpl<int64_t> &staticVec) {
89 for (OpFoldResult ofr : ofrs)
90 dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec);
91}
92
93/// Given a value, try to extract a constant Attribute. If this fails, return
94/// the original value.
96 if (!val)
97 return OpFoldResult();
98 Attribute attr;
99 if (matchPattern(val, m_Constant(&attr)))
100 return attr;
101 return val;
102}
103
104/// Given an array of values, try to extract a constant Attribute from each
105/// value. If this fails, return the original value.
107 return llvm::map_to_vector(values,
108 [](Value v) { return getAsOpFoldResult(v); });
109}
110
111/// Convert `arrayAttr` to a vector of OpFoldResult.
114 res.reserve(arrayAttr.size());
115 for (Attribute a : arrayAttr)
116 res.push_back(a);
117 return res;
118}
119
121 return IntegerAttr::get(IndexType::get(ctx), val);
122}
123
125 ArrayRef<int64_t> values) {
126 return llvm::map_to_vector(
127 values, [ctx](int64_t v) { return getAsIndexOpFoldResult(ctx, v); });
128}
129
130/// If ofr is a constant integer or an IntegerAttr, return the integer.
131/// The boolean indicates whether the value is an index type.
132std::optional<std::pair<APInt, bool>> getConstantAPIntValue(OpFoldResult ofr) {
133 // Case 1: Check for Constant integer.
134 if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
135 APInt intVal;
136 if (matchPattern(val, m_ConstantInt(&intVal)))
137 return std::make_pair(intVal, val.getType().isIndex());
138 return std::nullopt;
139 }
140 // Case 2: Check for IntegerAttr.
141 Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
142 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
143 return std::make_pair(intAttr.getValue(), intAttr.getType().isIndex());
144 return std::nullopt;
145}
146
147/// If ofr is a constant integer or an IntegerAttr, return the integer.
148std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
149 std::optional<std::pair<APInt, bool>> apInt = getConstantAPIntValue(ofr);
150 if (!apInt)
151 return std::nullopt;
152 return apInt->first.getSExtValue();
153}
154
155std::optional<SmallVector<int64_t>>
158 res.reserve(ofrs.size());
159 for (OpFoldResult ofr : ofrs) {
160 auto cv = getConstantIntValue(ofr);
161 if (!cv.has_value())
162 return std::nullopt;
163 res.push_back(cv.value());
164 }
165 return res;
166}
167
169 return getConstantIntValue(ofr) == value;
170}
171
173 return llvm::all_of(
174 ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
175}
176
178 ArrayRef<int64_t> values) {
179 if (ofrs.size() != values.size())
180 return false;
181 std::optional<SmallVector<int64_t>> constOfrs = getConstantIntValues(ofrs);
182 return constOfrs && llvm::equal(constOfrs.value(), values);
183}
184
185/// Return true if ofr1 and ofr2 are the same integer constant attribute values
186/// or the same SSA value.
187/// Ignore integer bitwidth and type mismatch that come from the fact there is
188/// no IndexAttr and that IndexType has no bitwidth.
190 auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
191 if (cst1 && cst2 && *cst1 == *cst2)
192 return true;
193 auto v1 = llvm::dyn_cast_if_present<Value>(ofr1),
194 v2 = llvm::dyn_cast_if_present<Value>(ofr2);
195 return v1 && v1 == v2;
196}
197
200 if (ofrs1.size() != ofrs2.size())
201 return false;
202 for (auto [ofr1, ofr2] : llvm::zip_equal(ofrs1, ofrs2))
203 if (!isEqualConstantIntOrValue(ofr1, ofr2))
204 return false;
205 return true;
206}
207
208/// Return a vector of OpFoldResults with the same size as staticValues, but all
209/// elements for which ShapedType::isDynamic is true, will be replaced by
210/// dynamicValues.
212 ValueRange dynamicValues,
213 MLIRContext *context) {
214 assert(dynamicValues.size() == static_cast<size_t>(llvm::count_if(
215 staticValues, ShapedType::isDynamic)) &&
216 "expected the rank of dynamic values to match the number of "
217 "values known to be dynamic");
219 res.reserve(staticValues.size());
220 unsigned numDynamic = 0;
221 unsigned count = static_cast<unsigned>(staticValues.size());
222 for (unsigned idx = 0; idx < count; ++idx) {
223 int64_t value = staticValues[idx];
224 res.push_back(ShapedType::isDynamic(value)
225 ? OpFoldResult{dynamicValues[numDynamic++]}
226 : OpFoldResult{IntegerAttr::get(
227 IntegerType::get(context, 64), staticValues[idx])});
228 }
229 return res;
230}
232 ValueRange dynamicValues, Builder &b) {
233 return getMixedValues(staticValues, dynamicValues, b.getContext());
234}
235
236/// Decompose a vector of mixed static or dynamic values into the corresponding
237/// pair of arrays. This is the inverse function of `getMixedValues`.
238std::pair<SmallVector<int64_t>, SmallVector<Value>>
240 SmallVector<int64_t> staticValues;
241 SmallVector<Value> dynamicValues;
242 for (const auto &it : mixedValues) {
243 if (auto attr = dyn_cast<Attribute>(it)) {
244 staticValues.push_back(cast<IntegerAttr>(attr).getInt());
245 } else {
246 staticValues.push_back(ShapedType::kDynamic);
247 dynamicValues.push_back(cast<Value>(it));
248 }
249 }
250 return {staticValues, dynamicValues};
251}
252
253/// Helper to sort `values` according to matching `keys`.
254template <typename K, typename V>
255static SmallVector<V>
257 llvm::function_ref<bool(K, K)> compare) {
258 if (keys.empty())
259 return SmallVector<V>{values};
260 assert(keys.size() == values.size() && "unexpected mismatching sizes");
261 auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
262 llvm::sort(indices,
263 [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
264 SmallVector<V> res;
265 res.reserve(values.size());
266 for (int64_t i = 0, e = indices.size(); i < e; ++i)
267 res.push_back(values[indices[i]]);
268 return res;
269}
270
271SmallVector<Value>
273 llvm::function_ref<bool(Attribute, Attribute)> compare) {
274 return getValuesSortedByKeyImpl(keys, values, compare);
275}
276
277SmallVector<OpFoldResult>
279 llvm::function_ref<bool(Attribute, Attribute)> compare) {
280 return getValuesSortedByKeyImpl(keys, values, compare);
281}
282
283SmallVector<int64_t>
285 llvm::function_ref<bool(Attribute, Attribute)> compare) {
286 return getValuesSortedByKeyImpl(keys, values, compare);
287}
288
289/// Return the number of iterations for a loop with a lower bound `lb`, upper
290/// bound `ub` and step `step`.
291std::optional<APInt> constantTripCount(
292 OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned,
293 llvm::function_ref<std::optional<llvm::APSInt>(Value, Value, bool)>
294 computeUbMinusLb) {
295 // This is the bitwidth used to return 0 when loop does not execute.
296 // We infer it from the type of the bound if it isn't an index type.
297 auto getBitwidth = [&](OpFoldResult ofr) -> std::tuple<int, bool> {
298 if (auto intAttr =
299 dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(ofr))) {
300 if (auto intType = dyn_cast<IntegerType>(intAttr.getType()))
301 return std::make_tuple(intType.getWidth(), intType.isIndex());
302 } else {
303 auto val = cast<Value>(ofr);
304 if (auto intType = dyn_cast<IntegerType>(val.getType()))
305 return std::make_tuple(intType.getWidth(), intType.isIndex());
306 }
307 return std::make_tuple(IndexType::kInternalStorageBitWidth, true);
308 };
309 auto [bitwidth, isIndex] = getBitwidth(lb);
310 // This would better be an assert, but unfortunately it breaks scf.for_all
311 // which is missing attributes and SSA value optionally for its bounds, and
312 // uses Index type for the dynamic bounds but i64 for the static bounds. This
313 // is broken...
314 if (std::tie(bitwidth, isIndex) != getBitwidth(ub)) {
315 LDBG() << "mismatch between lb and ub bitwidth/type: " << ub << " vs "
316 << lb;
317 return std::nullopt;
318 }
319 if (lb == ub) {
320 // Fast path: LB == UB. The loop has zero iterations.
321 // Note: LB and UB could match at runtime, even though they are different
322 // SSA values. That case cannot be detected here.
323 return APInt(bitwidth, 0);
324 }
325
326 std::optional<std::pair<APInt, bool>> maybeStepCst =
328
329 if (maybeStepCst) {
330 auto &stepCst = maybeStepCst->first;
331 assert(static_cast<int>(stepCst.getBitWidth()) == bitwidth &&
332 "step must have the same bitwidth as lb and ub");
333 if (stepCst.isZero()) {
334 // Step is zero. If LB and UB match, we have zero iterations. Otherwise,
335 // we have an infinite number of iterations. We cannot tell for sure which
336 // case applies, so the static trip count is unknown.
337 return std::nullopt;
338 }
339 }
340
341 if (isIndex) {
342 LDBG()
343 << "Computing loop trip count for index type may break with overflow";
344 // TODO: we can't compute the trip count for index type. We should fix this
345 // but too many tests are failing right now.
346 // return {};
347 }
348
349 /// Compute the difference between the upper and lower bound: either from the
350 /// constant value or using the computeUbMinusLb callback.
351 llvm::APSInt diff;
352 std::optional<std::pair<APInt, bool>> maybeLbCst = getConstantAPIntValue(lb);
353 std::optional<std::pair<APInt, bool>> maybeUbCst = getConstantAPIntValue(ub);
354 if (maybeLbCst) {
355 // If one of the bounds is not a constant, we can't compute the trip count.
356 if (!maybeUbCst)
357 return std::nullopt;
358 APSInt lbCst(maybeLbCst->first, /*isUnsigned=*/!isSigned);
359 APSInt ubCst(maybeUbCst->first, /*isUnsigned=*/!isSigned);
360 if (ubCst <= lbCst) {
361 LDBG() << "constantTripCount is 0 because ub <= lb (" << lbCst << "("
362 << lbCst.getBitWidth() << ") <= " << ubCst << "("
363 << ubCst.getBitWidth() << "), "
364 << (isSigned ? "isSigned" : "isUnsigned") << ")";
365 return APInt(bitwidth, 0);
366 }
367 // Compute the difference. Since we've already checked that ub > lb, the
368 // result can be interpreted as an unsigned value without overflow concerns.
369 diff = ubCst - lbCst;
370 // Convert diff to unsigned. This handles cases like i8: ub=127, lb=-128
371 // where the subtraction yields 255, which wraps to -1 in signed i8 but is
372 // correctly represented as 255 when interpreted as unsigned.
373 diff.setIsUnsigned(true);
374 } else {
375 if (maybeUbCst)
376 return std::nullopt;
377
378 /// Non-constant bound, let's try to compute the difference between the
379 /// upper and lower bound
380 std::optional<llvm::APSInt> maybeDiff =
381 computeUbMinusLb(cast<Value>(lb), cast<Value>(ub), isSigned);
382 if (!maybeDiff)
383 return std::nullopt;
384 diff = *maybeDiff;
385 }
386 LDBG() << "constantTripCount: " << (isSigned ? "isSigned" : "isUnsigned")
387 << ", ub-lb: " << diff << "(" << diff.getBitWidth() << "b)";
388 if (diff.isNegative()) {
389 LDBG() << "constantTripCount is 0 because ub-lb diff is negative";
390 return APInt(bitwidth, 0);
391 }
392 if (!maybeStepCst) {
393 LDBG()
394 << "constantTripCount can't be computed because step is not a constant";
395 return std::nullopt;
396 }
397 auto &stepCst = maybeStepCst->first;
398 // For signed loops, a negative step size could indicate an infinite number of
399 // iterations.
400 if (isSigned && stepCst.isSignBitSet()) {
401 LDBG() << "constantTripCount is infinite because step is negative";
402 return std::nullopt;
403 }
404
405 // Both diff and step are non-negative at this point (negative steps are
406 // rejected earlier), so we use unsigned division regardless of the loop
407 // comparison signedness.
408 llvm::APInt tripCount = diff.udiv(stepCst);
409 llvm::APInt remainder = diff.urem(stepCst);
410 if (!remainder.isZero())
411 tripCount = tripCount + 1;
412
413 LDBG() << "constantTripCount found: " << tripCount;
414 return tripCount;
415}
416
418 return llvm::none_of(sizesOrOffsets, [](int64_t value) {
419 return ShapedType::isStatic(value) && value < 0;
420 });
421}
422
424 return llvm::none_of(strides, [](int64_t value) {
425 return ShapedType::isStatic(value) && value == 0;
426 });
427}
428
430 bool onlyNonNegative, bool onlyNonZero) {
431 bool valuesChanged = false;
432 for (OpFoldResult &ofr : ofrs) {
433 if (isa<Attribute>(ofr))
434 continue;
435 Attribute attr;
436 if (matchPattern(cast<Value>(ofr), m_Constant(&attr))) {
437 // Note: All ofrs have index type.
438 if (onlyNonNegative && *getConstantIntValue(attr) < 0)
439 continue;
440 if (onlyNonZero && *getConstantIntValue(attr) == 0)
441 continue;
442 ofr = attr;
443 valuesChanged = true;
444 }
445 }
446 return success(valuesChanged);
447}
448
449LogicalResult
451 return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true,
452 /*onlyNonZero=*/false);
453}
454
456 return foldDynamicIndexList(strides, /*onlyNonNegative=*/false,
457 /*onlyNonZero=*/true);
458}
459
460} // namespace mlir
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
bool areConstantIntValues(ArrayRef< OpFoldResult > ofrs, ArrayRef< int64_t > values)
Return true if all of ofrs are constant integers equal to the corresponding value in values.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
static SmallVector< V > getValuesSortedByKeyImpl(ArrayRef< K > keys, ArrayRef< V > values, llvm::function_ref< bool(K, K)> compare)
Helper to sort values according to matching keys.
bool isZeroIntegerOrFloat(OpFoldResult v)
Return "true" if v is an integer/float value/attribute with constant value zero.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
std::optional< std::pair< APInt, bool > > getConstantAPIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition Matchers.h:399
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
bool isEqualConstantIntOrValueArray(ArrayRef< OpFoldResult > ofrs1, ArrayRef< OpFoldResult > ofrs2)
void dispatchIndexOpFoldResult(OpFoldResult ofr, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch an OpFoldResult into staticVec if: a) it is an IntegerAttr In other cases...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)
If all ofrs are constant integers or IntegerAttrs, return the integers.
SmallVector< Value > getValuesSortedByKey(ArrayRef< Attribute > keys, ArrayRef< Value > values, llvm::function_ref< bool(Attribute, Attribute)> compare)
Helper to sort values according to matching keys.
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step,...
bool isZeroFloat(OpFoldResult v)
Return "true" if v is a float value/attribute with constant value 0.0.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.