MLIR 23.0.0git
StaticValueUtils.cpp
Go to the documentation of this file.
1//===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/IR/Attributes.h"
11#include "mlir/IR/Matchers.h"
12#include "mlir/Support/LLVM.h"
13#include "llvm/ADT/APSInt.h"
14#include "llvm/ADT/STLExtras.h"
15#include "llvm/ADT/SmallVectorExtras.h"
16#include "llvm/Support/DebugLog.h"
17#include "llvm/Support/MathExtras.h"
18
19namespace mlir {
20
22
24 if (auto attr = dyn_cast<Attribute>(v)) {
25 if (auto floatAttr = dyn_cast<FloatAttr>(attr))
26 return floatAttr.getValue().isZero();
27 return false;
28 }
29 return matchPattern(cast<Value>(v), m_AnyZeroFloat());
30}
31
35
37
38std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
41 SmallVector<OpFoldResult> offsets, sizes, strides;
42 offsets.reserve(ranges.size());
43 sizes.reserve(ranges.size());
44 strides.reserve(ranges.size());
45 for (const auto &[offset, size, stride] : ranges) {
46 offsets.push_back(offset);
47 sizes.push_back(size);
48 strides.push_back(stride);
49 }
50 return std::make_tuple(offsets, sizes, strides);
51}
52
53/// Helper function to dispatch an OpFoldResult into `staticVec` if:
54/// a) it is an IntegerAttr
55/// In other cases, the OpFoldResult is dispached to the `dynamicVec`.
56/// In such dynamic cases, a copy of the `sentinel` value is also pushed to
57/// `staticVec`. This is useful to extract mixed static and dynamic entries that
58/// come from an AttrSizedOperandSegments trait.
60 SmallVectorImpl<Value> &dynamicVec,
61 SmallVectorImpl<int64_t> &staticVec) {
62 auto v = llvm::dyn_cast_if_present<Value>(ofr);
63 if (!v) {
64 APInt apInt = cast<IntegerAttr>(cast<Attribute>(ofr)).getValue();
65 staticVec.push_back(apInt.getSExtValue());
66 return;
67 }
68 dynamicVec.push_back(v);
69 staticVec.push_back(ShapedType::kDynamic);
70}
71
72std::pair<int64_t, OpFoldResult>
74 int64_t tileSizeForShape =
75 getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
76
77 OpFoldResult tileSizeOfrSimplified =
78 (tileSizeForShape != ShapedType::kDynamic)
79 ? b.getIndexAttr(tileSizeForShape)
80 : tileSizeOfr;
81
82 return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
83 tileSizeOfrSimplified);
84}
85
87 SmallVectorImpl<Value> &dynamicVec,
88 SmallVectorImpl<int64_t> &staticVec) {
89 for (OpFoldResult ofr : ofrs)
90 dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec);
91}
92
93/// Given a value, try to extract a constant Attribute. If this fails, return
94/// the original value.
96 if (!val)
97 return OpFoldResult();
98 Attribute attr;
99 if (matchPattern(val, m_Constant(&attr)))
100 return attr;
101 return val;
102}
103
104/// Given an array of values, try to extract a constant Attribute from each
105/// value. If this fails, return the original value.
107 return llvm::map_to_vector(values,
108 [](Value v) { return getAsOpFoldResult(v); });
109}
110
111/// Convert `arrayAttr` to a vector of OpFoldResult.
114 res.reserve(arrayAttr.size());
115 for (Attribute a : arrayAttr)
116 res.push_back(a);
117 return res;
118}
119
121 return IntegerAttr::get(IndexType::get(ctx), val);
122}
123
125 ArrayRef<int64_t> values) {
126 return llvm::map_to_vector(
127 values, [ctx](int64_t v) { return getAsIndexOpFoldResult(ctx, v); });
128}
129
130/// If ofr is a constant integer or an IntegerAttr, return the integer.
131/// The boolean indicates whether the value is an index type.
132std::optional<std::pair<APInt, bool>> getConstantAPIntValue(OpFoldResult ofr) {
133 // Case 1: Check for Constant integer.
134 if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
135 APInt intVal;
136 if (matchPattern(val, m_ConstantInt(&intVal)))
137 return std::make_pair(intVal, val.getType().isIndex());
138 return std::nullopt;
139 }
140 // Case 2: Check for IntegerAttr.
141 Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
142 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
143 return std::make_pair(intAttr.getValue(), intAttr.getType().isIndex());
144 return std::nullopt;
145}
146
147/// If ofr is a constant integer or an IntegerAttr, return the integer.
148std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
149 std::optional<std::pair<APInt, bool>> apInt = getConstantAPIntValue(ofr);
150 if (!apInt)
151 return std::nullopt;
152 return apInt->first.getSExtValue();
153}
154
155std::optional<SmallVector<int64_t>>
157 bool failed = false;
158 SmallVector<int64_t> res = llvm::map_to_vector(ofrs, [&](OpFoldResult ofr) {
159 auto cv = getConstantIntValue(ofr);
160 if (!cv.has_value())
161 failed = true;
162 return cv.value_or(0);
163 });
164 if (failed)
165 return std::nullopt;
166 return res;
167}
168
170 return getConstantIntValue(ofr) == value;
171}
172
174 return llvm::all_of(
175 ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
176}
177
179 ArrayRef<int64_t> values) {
180 if (ofrs.size() != values.size())
181 return false;
182 std::optional<SmallVector<int64_t>> constOfrs = getConstantIntValues(ofrs);
183 return constOfrs && llvm::equal(constOfrs.value(), values);
184}
185
186/// Return true if ofr1 and ofr2 are the same integer constant attribute values
187/// or the same SSA value.
188/// Ignore integer bitwidth and type mismatch that come from the fact there is
189/// no IndexAttr and that IndexType has no bitwidth.
191 auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
192 if (cst1 && cst2 && *cst1 == *cst2)
193 return true;
194 auto v1 = llvm::dyn_cast_if_present<Value>(ofr1),
195 v2 = llvm::dyn_cast_if_present<Value>(ofr2);
196 return v1 && v1 == v2;
197}
198
201 if (ofrs1.size() != ofrs2.size())
202 return false;
203 for (auto [ofr1, ofr2] : llvm::zip_equal(ofrs1, ofrs2))
204 if (!isEqualConstantIntOrValue(ofr1, ofr2))
205 return false;
206 return true;
207}
208
209/// Return a vector of OpFoldResults with the same size as staticValues, but all
210/// elements for which ShapedType::isDynamic is true, will be replaced by
211/// dynamicValues.
213 ValueRange dynamicValues,
214 MLIRContext *context) {
215 assert(dynamicValues.size() == static_cast<size_t>(llvm::count_if(
216 staticValues, ShapedType::isDynamic)) &&
217 "expected the rank of dynamic values to match the number of "
218 "values known to be dynamic");
220 res.reserve(staticValues.size());
221 unsigned numDynamic = 0;
222 unsigned count = static_cast<unsigned>(staticValues.size());
223 for (unsigned idx = 0; idx < count; ++idx) {
224 int64_t value = staticValues[idx];
225 res.push_back(ShapedType::isDynamic(value)
226 ? OpFoldResult{dynamicValues[numDynamic++]}
227 : OpFoldResult{IntegerAttr::get(
228 IntegerType::get(context, 64), staticValues[idx])});
229 }
230 return res;
231}
233 ValueRange dynamicValues, Builder &b) {
234 return getMixedValues(staticValues, dynamicValues, b.getContext());
235}
236
237/// Decompose a vector of mixed static or dynamic values into the corresponding
238/// pair of arrays. This is the inverse function of `getMixedValues`.
239std::pair<SmallVector<int64_t>, SmallVector<Value>>
241 SmallVector<int64_t> staticValues;
242 SmallVector<Value> dynamicValues;
243 for (const auto &it : mixedValues) {
244 if (auto attr = dyn_cast<Attribute>(it)) {
245 staticValues.push_back(cast<IntegerAttr>(attr).getInt());
246 } else {
247 staticValues.push_back(ShapedType::kDynamic);
248 dynamicValues.push_back(cast<Value>(it));
249 }
250 }
251 return {staticValues, dynamicValues};
252}
253
254/// Helper to sort `values` according to matching `keys`.
255template <typename K, typename V>
256static SmallVector<V>
258 llvm::function_ref<bool(K, K)> compare) {
259 if (keys.empty())
260 return SmallVector<V>{values};
261 assert(keys.size() == values.size() && "unexpected mismatching sizes");
262 auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
263 llvm::sort(indices,
264 [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
265 SmallVector<V> res;
266 res.reserve(values.size());
267 for (int64_t i = 0, e = indices.size(); i < e; ++i)
268 res.push_back(values[indices[i]]);
269 return res;
270}
271
274 llvm::function_ref<bool(Attribute, Attribute)> compare) {
275 return getValuesSortedByKeyImpl(keys, values, compare);
276}
277
280 llvm::function_ref<bool(Attribute, Attribute)> compare) {
281 return getValuesSortedByKeyImpl(keys, values, compare);
282}
283
286 llvm::function_ref<bool(Attribute, Attribute)> compare) {
287 return getValuesSortedByKeyImpl(keys, values, compare);
288}
289
290/// Return the number of iterations for a loop with a lower bound `lb`, upper
291/// bound `ub` and step `step`.
292std::optional<APInt> constantTripCount(
293 OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned,
294 llvm::function_ref<std::optional<llvm::APSInt>(Value, Value, bool)>
295 computeUbMinusLb) {
296 // This is the bitwidth used to return 0 when loop does not execute.
297 // We infer it from the type of the bound if it isn't an index type.
298 auto getBitwidth = [&](OpFoldResult ofr) -> std::tuple<int, bool> {
299 if (auto intAttr =
300 dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(ofr))) {
301 if (auto intType = dyn_cast<IntegerType>(intAttr.getType()))
302 return std::make_tuple(intType.getWidth(), intType.isIndex());
303 } else {
304 auto val = cast<Value>(ofr);
305 if (auto intType = dyn_cast<IntegerType>(val.getType()))
306 return std::make_tuple(intType.getWidth(), intType.isIndex());
307 }
308 return std::make_tuple(IndexType::kInternalStorageBitWidth, true);
309 };
310 auto [bitwidth, isIndex] = getBitwidth(lb);
311 // This would better be an assert, but unfortunately it breaks scf.for_all
312 // which is missing attributes and SSA value optionally for its bounds, and
313 // uses Index type for the dynamic bounds but i64 for the static bounds. This
314 // is broken...
315 if (std::tie(bitwidth, isIndex) != getBitwidth(ub)) {
316 LDBG() << "mismatch between lb and ub bitwidth/type: " << ub << " vs "
317 << lb;
318 return std::nullopt;
319 }
320 if (lb == ub) {
321 // Fast path: LB == UB. The loop has zero iterations.
322 // Note: LB and UB could match at runtime, even though they are different
323 // SSA values. That case cannot be detected here.
324 return APInt(bitwidth, 0);
325 }
326
327 std::optional<std::pair<APInt, bool>> maybeStepCst =
329
330 if (maybeStepCst) {
331 auto &stepCst = maybeStepCst->first;
332 assert(static_cast<int>(stepCst.getBitWidth()) == bitwidth &&
333 "step must have the same bitwidth as lb and ub");
334 if (stepCst.isZero()) {
335 // Step is zero. If LB and UB match, we have zero iterations. Otherwise,
336 // we have an infinite number of iterations. We cannot tell for sure which
337 // case applies, so the static trip count is unknown.
338 return std::nullopt;
339 }
340 }
341
342 if (isIndex) {
343 LDBG()
344 << "Computing loop trip count for index type may break with overflow";
345 // TODO: we can't compute the trip count for index type. We should fix this
346 // but too many tests are failing right now.
347 // return {};
348 }
349
350 /// Compute the difference between the upper and lower bound: either from the
351 /// constant value or using the computeUbMinusLb callback.
352 llvm::APSInt diff;
353 std::optional<std::pair<APInt, bool>> maybeLbCst = getConstantAPIntValue(lb);
354 std::optional<std::pair<APInt, bool>> maybeUbCst = getConstantAPIntValue(ub);
355 if (maybeLbCst) {
356 // If one of the bounds is not a constant, we can't compute the trip count.
357 if (!maybeUbCst)
358 return std::nullopt;
359 APSInt lbCst(maybeLbCst->first, /*isUnsigned=*/!isSigned);
360 APSInt ubCst(maybeUbCst->first, /*isUnsigned=*/!isSigned);
361 if (ubCst <= lbCst) {
362 LDBG() << "constantTripCount is 0 because ub <= lb (" << lbCst << "("
363 << lbCst.getBitWidth() << ") <= " << ubCst << "("
364 << ubCst.getBitWidth() << "), "
365 << (isSigned ? "isSigned" : "isUnsigned") << ")";
366 return APInt(bitwidth, 0);
367 }
368 diff = ubCst - lbCst;
369 } else {
370 if (maybeUbCst)
371 return std::nullopt;
372
373 /// Non-constant bound, let's try to compute the difference between the
374 /// upper and lower bound
375 std::optional<llvm::APSInt> maybeDiff =
376 computeUbMinusLb(cast<Value>(lb), cast<Value>(ub), isSigned);
377 if (!maybeDiff)
378 return std::nullopt;
379 diff = *maybeDiff;
380 }
381 LDBG() << "constantTripCount: " << (isSigned ? "isSigned" : "isUnsigned")
382 << ", ub-lb: " << diff << "(" << diff.getBitWidth() << "b)";
383 if (diff.isNegative()) {
384 LDBG() << "constantTripCount is 0 because ub-lb diff is negative";
385 return APInt(bitwidth, 0);
386 }
387 if (!maybeStepCst) {
388 LDBG()
389 << "constantTripCount can't be computed because step is not a constant";
390 return std::nullopt;
391 }
392 auto &stepCst = maybeStepCst->first;
393 // For signed loops, a negative step size could indicate an infinite number of
394 // iterations.
395 if (isSigned && stepCst.isSignBitSet()) {
396 LDBG() << "constantTripCount is infinite because step is negative";
397 return std::nullopt;
398 }
399
400 // Create new APSInt instances with explicit signedness to ensure they match
401 llvm::APInt tripCount = isSigned ? diff.sdiv(stepCst) : diff.udiv(stepCst);
402 llvm::APInt remainder = isSigned ? diff.srem(stepCst) : diff.urem(stepCst);
403 if (!remainder.isZero())
404 tripCount = tripCount + 1;
405 LDBG() << "constantTripCount found: " << tripCount;
406 return tripCount;
407}
408
410 return llvm::none_of(sizesOrOffsets, [](int64_t value) {
411 return ShapedType::isStatic(value) && value < 0;
412 });
413}
414
416 return llvm::none_of(strides, [](int64_t value) {
417 return ShapedType::isStatic(value) && value == 0;
418 });
419}
420
422 bool onlyNonNegative, bool onlyNonZero) {
423 bool valuesChanged = false;
424 for (OpFoldResult &ofr : ofrs) {
425 if (isa<Attribute>(ofr))
426 continue;
427 Attribute attr;
428 if (matchPattern(cast<Value>(ofr), m_Constant(&attr))) {
429 // Note: All ofrs have index type.
430 if (onlyNonNegative && *getConstantIntValue(attr) < 0)
431 continue;
432 if (onlyNonZero && *getConstantIntValue(attr) == 0)
433 continue;
434 ofr = attr;
435 valuesChanged = true;
436 }
437 }
438 return success(valuesChanged);
439}
440
441LogicalResult
443 return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true,
444 /*onlyNonZero=*/false);
445}
446
448 return foldDynamicIndexList(strides, /*onlyNonNegative=*/false,
449 /*onlyNonZero=*/true);
450}
451
452} // namespace mlir
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
bool areConstantIntValues(ArrayRef< OpFoldResult > ofrs, ArrayRef< int64_t > values)
Return true if all of ofrs are constant integers equal to the corresponding value in values.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
static SmallVector< V > getValuesSortedByKeyImpl(ArrayRef< K > keys, ArrayRef< V > values, llvm::function_ref< bool(K, K)> compare)
Helper to sort values according to matching keys.
bool isZeroIntegerOrFloat(OpFoldResult v)
Return "true" if v is an integer/float value/attribute with constant value zero.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
std::optional< std::pair< APInt, bool > > getConstantAPIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition Matchers.h:399
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
bool isEqualConstantIntOrValueArray(ArrayRef< OpFoldResult > ofrs1, ArrayRef< OpFoldResult > ofrs2)
void dispatchIndexOpFoldResult(OpFoldResult ofr, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch an OpFoldResult into staticVec if: a) it is an IntegerAttr In other cases...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)
If all ofrs are constant integers or IntegerAttrs, return the integers.
SmallVector< Value > getValuesSortedByKey(ArrayRef< Attribute > keys, ArrayRef< Value > values, llvm::function_ref< bool(Attribute, Attribute)> compare)
Helper to sort values according to matching keys.
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step,...
bool isZeroFloat(OpFoldResult v)
Return "true" if v is a float value/attribute with constant value 0.0.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.