MLIR 22.0.0git
StaticValueUtils.cpp
Go to the documentation of this file.
1//===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/IR/Attributes.h"
11#include "mlir/IR/Matchers.h"
12#include "mlir/Support/LLVM.h"
13#include "llvm/ADT/APSInt.h"
14#include "llvm/ADT/STLExtras.h"
15#include "llvm/Support/DebugLog.h"
16#include "llvm/Support/MathExtras.h"
17
18namespace mlir {
19
21
23 if (auto attr = dyn_cast<Attribute>(v)) {
24 if (auto floatAttr = dyn_cast<FloatAttr>(attr))
25 return floatAttr.getValue().isZero();
26 return false;
27 }
28 return matchPattern(cast<Value>(v), m_AnyZeroFloat());
29}
30
34
36
37std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
40 SmallVector<OpFoldResult> offsets, sizes, strides;
41 offsets.reserve(ranges.size());
42 sizes.reserve(ranges.size());
43 strides.reserve(ranges.size());
44 for (const auto &[offset, size, stride] : ranges) {
45 offsets.push_back(offset);
46 sizes.push_back(size);
47 strides.push_back(stride);
48 }
49 return std::make_tuple(offsets, sizes, strides);
50}
51
52/// Helper function to dispatch an OpFoldResult into `staticVec` if:
53/// a) it is an IntegerAttr
54/// In other cases, the OpFoldResult is dispatched to the `dynamicVec`.
55/// In such dynamic cases, a copy of the `sentinel` value is also pushed to
56/// `staticVec`. This is useful to extract mixed static and dynamic entries that
57/// come from an AttrSizedOperandSegments trait.
59 SmallVectorImpl<Value> &dynamicVec,
60 SmallVectorImpl<int64_t> &staticVec) {
61 auto v = llvm::dyn_cast_if_present<Value>(ofr);
62 if (!v) {
63 APInt apInt = cast<IntegerAttr>(cast<Attribute>(ofr)).getValue();
64 staticVec.push_back(apInt.getSExtValue());
65 return;
66 }
67 dynamicVec.push_back(v);
68 staticVec.push_back(ShapedType::kDynamic);
69}
70
71std::pair<int64_t, OpFoldResult>
73 int64_t tileSizeForShape =
74 getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
75
76 OpFoldResult tileSizeOfrSimplified =
77 (tileSizeForShape != ShapedType::kDynamic)
78 ? b.getIndexAttr(tileSizeForShape)
79 : tileSizeOfr;
80
81 return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
82 tileSizeOfrSimplified);
83}
84
86 SmallVectorImpl<Value> &dynamicVec,
87 SmallVectorImpl<int64_t> &staticVec) {
88 for (OpFoldResult ofr : ofrs)
89 dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec);
90}
91
92/// Given a value, try to extract a constant Attribute. If this fails, return
93/// the original value.
95 if (!val)
96 return OpFoldResult();
97 Attribute attr;
98 if (matchPattern(val, m_Constant(&attr)))
99 return attr;
100 return val;
101}
102
103/// Given an array of values, try to extract a constant Attribute from each
104/// value. If this fails, return the original value.
106 return llvm::to_vector(
107 llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
108}
109
110/// Convert `arrayAttr` to a vector of OpFoldResult.
113 res.reserve(arrayAttr.size());
114 for (Attribute a : arrayAttr)
115 res.push_back(a);
116 return res;
117}
118
120 return IntegerAttr::get(IndexType::get(ctx), val);
121}
122
124 ArrayRef<int64_t> values) {
125 return llvm::to_vector(llvm::map_range(
126 values, [ctx](int64_t v) { return getAsIndexOpFoldResult(ctx, v); }));
127}
128
129/// If ofr is a constant integer or an IntegerAttr, return the integer.
130/// The boolean indicates whether the value is an index type.
131std::optional<std::pair<APInt, bool>> getConstantAPIntValue(OpFoldResult ofr) {
132 // Case 1: Check for Constant integer.
133 if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
134 APInt intVal;
135 if (matchPattern(val, m_ConstantInt(&intVal)))
136 return std::make_pair(intVal, val.getType().isIndex());
137 return std::nullopt;
138 }
139 // Case 2: Check for IntegerAttr.
140 Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
141 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
142 return std::make_pair(intAttr.getValue(), intAttr.getType().isIndex());
143 return std::nullopt;
144}
145
146/// If ofr is a constant integer or an IntegerAttr, return the integer.
147std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
148 std::optional<std::pair<APInt, bool>> apInt = getConstantAPIntValue(ofr);
149 if (!apInt)
150 return std::nullopt;
151 return apInt->first.getSExtValue();
152}
153
154std::optional<SmallVector<int64_t>>
156 bool failed = false;
157 SmallVector<int64_t> res = llvm::map_to_vector(ofrs, [&](OpFoldResult ofr) {
158 auto cv = getConstantIntValue(ofr);
159 if (!cv.has_value())
160 failed = true;
161 return cv.value_or(0);
162 });
163 if (failed)
164 return std::nullopt;
165 return res;
166}
167
169 return getConstantIntValue(ofr) == value;
170}
171
173 return llvm::all_of(
174 ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
175}
176
178 ArrayRef<int64_t> values) {
179 if (ofrs.size() != values.size())
180 return false;
181 std::optional<SmallVector<int64_t>> constOfrs = getConstantIntValues(ofrs);
182 return constOfrs && llvm::equal(constOfrs.value(), values);
183}
184
185/// Return true if ofr1 and ofr2 are the same integer constant attribute values
186/// or the same SSA value.
187/// Ignore integer bitwidth and type mismatch that come from the fact there is
188/// no IndexAttr and that IndexType has no bitwidth.
190 auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
191 if (cst1 && cst2 && *cst1 == *cst2)
192 return true;
193 auto v1 = llvm::dyn_cast_if_present<Value>(ofr1),
194 v2 = llvm::dyn_cast_if_present<Value>(ofr2);
195 return v1 && v1 == v2;
196}
197
200 if (ofrs1.size() != ofrs2.size())
201 return false;
202 for (auto [ofr1, ofr2] : llvm::zip_equal(ofrs1, ofrs2))
203 if (!isEqualConstantIntOrValue(ofr1, ofr2))
204 return false;
205 return true;
206}
207
208/// Return a vector of OpFoldResults with the same size as staticValues, but all
209/// elements for which ShapedType::isDynamic is true, will be replaced by
210/// dynamicValues.
212 ValueRange dynamicValues,
213 MLIRContext *context) {
214 assert(dynamicValues.size() == static_cast<size_t>(llvm::count_if(
215 staticValues, ShapedType::isDynamic)) &&
216 "expected the rank of dynamic values to match the number of "
217 "values known to be dynamic");
219 res.reserve(staticValues.size());
220 unsigned numDynamic = 0;
221 unsigned count = static_cast<unsigned>(staticValues.size());
222 for (unsigned idx = 0; idx < count; ++idx) {
223 int64_t value = staticValues[idx];
224 res.push_back(ShapedType::isDynamic(value)
225 ? OpFoldResult{dynamicValues[numDynamic++]}
226 : OpFoldResult{IntegerAttr::get(
227 IntegerType::get(context, 64), staticValues[idx])});
228 }
229 return res;
230}
232 ValueRange dynamicValues, Builder &b) {
233 return getMixedValues(staticValues, dynamicValues, b.getContext());
234}
235
236/// Decompose a vector of mixed static or dynamic values into the corresponding
237/// pair of arrays. This is the inverse function of `getMixedValues`.
238std::pair<SmallVector<int64_t>, SmallVector<Value>>
240 SmallVector<int64_t> staticValues;
241 SmallVector<Value> dynamicValues;
242 for (const auto &it : mixedValues) {
243 if (auto attr = dyn_cast<Attribute>(it)) {
244 staticValues.push_back(cast<IntegerAttr>(attr).getInt());
245 } else {
246 staticValues.push_back(ShapedType::kDynamic);
247 dynamicValues.push_back(cast<Value>(it));
248 }
249 }
250 return {staticValues, dynamicValues};
251}
252
253/// Helper to sort `values` according to matching `keys`.
254template <typename K, typename V>
255static SmallVector<V>
257 llvm::function_ref<bool(K, K)> compare) {
258 if (keys.empty())
259 return SmallVector<V>{values};
260 assert(keys.size() == values.size() && "unexpected mismatching sizes");
261 auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
262 llvm::sort(indices,
263 [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
264 SmallVector<V> res;
265 res.reserve(values.size());
266 for (int64_t i = 0, e = indices.size(); i < e; ++i)
267 res.push_back(values[indices[i]]);
268 return res;
269}
270
273 llvm::function_ref<bool(Attribute, Attribute)> compare) {
274 return getValuesSortedByKeyImpl(keys, values, compare);
275}
276
279 llvm::function_ref<bool(Attribute, Attribute)> compare) {
280 return getValuesSortedByKeyImpl(keys, values, compare);
281}
282
285 llvm::function_ref<bool(Attribute, Attribute)> compare) {
286 return getValuesSortedByKeyImpl(keys, values, compare);
287}
288
289/// Return the number of iterations for a loop with a lower bound `lb`, upper
290/// bound `ub` and step `step`.
291std::optional<APInt> constantTripCount(
292 OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned,
293 llvm::function_ref<std::optional<llvm::APSInt>(Value, Value, bool)>
294 computeUbMinusLb) {
295 // This is the bitwidth used to return 0 when loop does not execute.
296 // We infer it from the type of the bound if it isn't an index type.
297 auto getBitwidth = [&](OpFoldResult ofr) -> std::tuple<int, bool> {
298 if (auto intAttr =
299 dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(ofr))) {
300 if (auto intType = dyn_cast<IntegerType>(intAttr.getType()))
301 return std::make_tuple(intType.getWidth(), intType.isIndex());
302 } else {
303 auto val = cast<Value>(ofr);
304 if (auto intType = dyn_cast<IntegerType>(val.getType()))
305 return std::make_tuple(intType.getWidth(), intType.isIndex());
306 }
307 return std::make_tuple(IndexType::kInternalStorageBitWidth, true);
308 };
309 auto [bitwidth, isIndex] = getBitwidth(lb);
310 // This would better be an assert, but unfortunately it breaks scf.for_all
311 // which is missing attributes and SSA value optionally for its bounds, and
312 // uses Index type for the dynamic bounds but i64 for the static bounds. This
313 // is broken...
314 if (std::tie(bitwidth, isIndex) != getBitwidth(ub)) {
315 LDBG() << "mismatch between lb and ub bitwidth/type: " << ub << " vs "
316 << lb;
317 return std::nullopt;
318 }
319 if (lb == ub)
320 return APInt(bitwidth, 0);
321
322 std::optional<std::pair<APInt, bool>> maybeStepCst =
324
325 if (maybeStepCst) {
326 auto &stepCst = maybeStepCst->first;
327 assert(static_cast<int>(stepCst.getBitWidth()) == bitwidth &&
328 "step must have the same bitwidth as lb and ub");
329 if (stepCst.isZero())
330 return stepCst;
331 if (stepCst.isNegative())
332 return APInt(bitwidth, 0);
333 }
334
335 if (isIndex) {
336 LDBG()
337 << "Computing loop trip count for index type may break with overflow";
338 // TODO: we can't compute the trip count for index type. We should fix this
339 // but too many tests are failing right now.
340 // return {};
341 }
342
343 /// Compute the difference between the upper and lower bound: either from the
344 /// constant value or using the computeUbMinusLb callback.
345 llvm::APSInt diff;
346 std::optional<std::pair<APInt, bool>> maybeLbCst = getConstantAPIntValue(lb);
347 std::optional<std::pair<APInt, bool>> maybeUbCst = getConstantAPIntValue(ub);
348 if (maybeLbCst) {
349 // If one of the bounds is not a constant, we can't compute the trip count.
350 if (!maybeUbCst)
351 return std::nullopt;
352 APSInt lbCst(maybeLbCst->first, /*isUnsigned=*/!isSigned);
353 APSInt ubCst(maybeUbCst->first, /*isUnsigned=*/!isSigned);
354 if (ubCst <= lbCst) {
355 LDBG() << "constantTripCount is 0 because ub <= lb (" << lbCst << "("
356 << lbCst.getBitWidth() << ") <= " << ubCst << "("
357 << ubCst.getBitWidth() << "), "
358 << (isSigned ? "isSigned" : "isUnsigned") << ")";
359 return APInt(bitwidth, 0);
360 }
361 diff = ubCst - lbCst;
362 } else {
363 if (maybeUbCst)
364 return std::nullopt;
365
366 /// Non-constant bound, let's try to compute the difference between the
367 /// upper and lower bound
368 std::optional<llvm::APSInt> maybeDiff =
369 computeUbMinusLb(cast<Value>(lb), cast<Value>(ub), isSigned);
370 if (!maybeDiff)
371 return std::nullopt;
372 diff = *maybeDiff;
373 }
374 LDBG() << "constantTripCount: " << (isSigned ? "isSigned" : "isUnsigned")
375 << ", ub-lb: " << diff << "(" << diff.getBitWidth() << "b)";
376 if (diff.isNegative()) {
377 LDBG() << "constantTripCount is 0 because ub-lb diff is negative";
378 return APInt(bitwidth, 0);
379 }
380 if (!maybeStepCst) {
381 LDBG()
382 << "constantTripCount can't be computed because step is not a constant";
383 return std::nullopt;
384 }
385 auto &stepCst = maybeStepCst->first;
386 llvm::APInt tripCount = isSigned ? diff.sdiv(stepCst) : diff.udiv(stepCst);
387 llvm::APInt remainder = isSigned ? diff.srem(stepCst) : diff.urem(stepCst);
388 if (!remainder.isZero())
389 tripCount = tripCount + 1;
390 LDBG() << "constantTripCount found: " << tripCount;
391 return tripCount;
392}
393
395 return llvm::none_of(sizesOrOffsets, [](int64_t value) {
396 return ShapedType::isStatic(value) && value < 0;
397 });
398}
399
401 return llvm::none_of(strides, [](int64_t value) {
402 return ShapedType::isStatic(value) && value == 0;
403 });
404}
405
407 bool onlyNonNegative, bool onlyNonZero) {
408 bool valuesChanged = false;
409 for (OpFoldResult &ofr : ofrs) {
410 if (isa<Attribute>(ofr))
411 continue;
412 Attribute attr;
413 if (matchPattern(cast<Value>(ofr), m_Constant(&attr))) {
414 // Note: All ofrs have index type.
415 if (onlyNonNegative && *getConstantIntValue(attr) < 0)
416 continue;
417 if (onlyNonZero && *getConstantIntValue(attr) == 0)
418 continue;
419 ofr = attr;
420 valuesChanged = true;
421 }
422 }
423 return success(valuesChanged);
424}
425
426LogicalResult
428 return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true,
429 /*onlyNonZero=*/false);
430}
431
433 return foldDynamicIndexList(strides, /*onlyNonNegative=*/false,
434 /*onlyNonZero=*/true);
435}
436
437} // namespace mlir
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
bool areConstantIntValues(ArrayRef< OpFoldResult > ofrs, ArrayRef< int64_t > values)
Return true if all of ofrs are constant integers equal to the corresponding value in values.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
static SmallVector< V > getValuesSortedByKeyImpl(ArrayRef< K > keys, ArrayRef< V > values, llvm::function_ref< bool(K, K)> compare)
Helper to sort values according to matching keys.
bool isZeroIntegerOrFloat(OpFoldResult v)
Return "true" if v is an integer/float value/attribute with constant value zero.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
std::optional< std::pair< APInt, bool > > getConstantAPIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition Matchers.h:399
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
bool isEqualConstantIntOrValueArray(ArrayRef< OpFoldResult > ofrs1, ArrayRef< OpFoldResult > ofrs2)
void dispatchIndexOpFoldResult(OpFoldResult ofr, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch an OpFoldResult into staticVec if: a) it is an IntegerAttr In other cases...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)
If all ofrs are constant integers or IntegerAttrs, return the integers.
SmallVector< Value > getValuesSortedByKey(ArrayRef< Attribute > keys, ArrayRef< Value > values, llvm::function_ref< bool(Attribute, Attribute)> compare)
Helper to sort values according to matching keys.
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
bool isZeroFloat(OpFoldResult v)
Return "true" if v is a float value/attribute with constant value 0.0.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.