13 #include "llvm/ADT/APSInt.h"
14 #include "llvm/ADT/STLExtras.h"
15 #include "llvm/Support/DebugLog.h"
16 #include "llvm/Support/MathExtras.h"
24 std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
25 SmallVector<OpFoldResult>>
28 offsets.reserve(ranges.size());
29 sizes.reserve(ranges.size());
30 strides.reserve(ranges.size());
31 for (
const auto &[offset, size, stride] : ranges) {
32 offsets.push_back(offset);
33 sizes.push_back(size);
34 strides.push_back(stride);
36 return std::make_tuple(offsets, sizes, strides);
48 auto v = llvm::dyn_cast_if_present<Value>(ofr);
50 APInt apInt = cast<IntegerAttr>(cast<Attribute>(ofr)).getValue();
51 staticVec.push_back(apInt.getSExtValue());
54 dynamicVec.push_back(v);
55 staticVec.push_back(ShapedType::kDynamic);
58 std::pair<int64_t, OpFoldResult>
60 int64_t tileSizeForShape =
64 (tileSizeForShape != ShapedType::kDynamic)
68 return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
69 tileSizeOfrSimplified);
93 return llvm::to_vector(
100 res.reserve(arrayAttr.size());
112 return llvm::to_vector(llvm::map_range(
120 if (
auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
123 return std::make_pair(intVal, val.getType().isIndex());
127 Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
128 if (
auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
129 return std::make_pair(intAttr.getValue(), intAttr.getType().isIndex());
138 return apInt->first.getSExtValue();
141 std::optional<SmallVector<int64_t>>
148 return cv.value_or(0);
166 if (ofrs.size() != values.size())
169 return constOfrs && llvm::equal(constOfrs.value(), values);
178 if (cst1 && cst2 && *cst1 == *cst2)
180 auto v1 = llvm::dyn_cast_if_present<Value>(ofr1),
181 v2 = llvm::dyn_cast_if_present<Value>(ofr2);
182 return v1 && v1 == v2;
187 if (ofrs1.size() != ofrs2.size())
189 for (
auto [ofr1, ofr2] : llvm::zip_equal(ofrs1, ofrs2))
201 assert(dynamicValues.size() ==
static_cast<size_t>(llvm::count_if(
202 staticValues, ShapedType::isDynamic)) &&
203 "expected the rank of dynamic values to match the number of "
204 "values known to be dynamic");
206 res.reserve(staticValues.size());
207 unsigned numDynamic = 0;
208 unsigned count =
static_cast<unsigned>(staticValues.size());
209 for (
unsigned idx = 0; idx < count; ++idx) {
210 int64_t value = staticValues[idx];
211 res.push_back(ShapedType::isDynamic(value)
225 std::pair<SmallVector<int64_t>, SmallVector<Value>>
229 for (
const auto &it : mixedValues) {
230 if (
auto attr = dyn_cast<Attribute>(it)) {
231 staticValues.push_back(cast<IntegerAttr>(attr).getInt());
233 staticValues.push_back(ShapedType::kDynamic);
234 dynamicValues.push_back(cast<Value>(it));
237 return {staticValues, dynamicValues};
241 template <
typename K,
typename V>
242 static SmallVector<V>
247 assert(keys.size() == values.size() &&
"unexpected mismatching sizes");
248 auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
250 [&](int64_t i, int64_t
j) {
return compare(keys[i], keys[
j]); });
252 res.reserve(values.size());
253 for (int64_t i = 0, e = indices.size(); i < e; ++i)
254 res.push_back(values[indices[i]]);
264 SmallVector<OpFoldResult>
284 auto getBitwidth = [&](
OpFoldResult ofr) -> std::tuple<int, bool> {
286 dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(ofr))) {
287 if (
auto intType = dyn_cast<IntegerType>(intAttr.getType()))
288 return std::make_tuple(intType.getWidth(), intType.isIndex());
290 auto val = cast<Value>(ofr);
291 if (
auto intType = dyn_cast<IntegerType>(val.getType()))
292 return std::make_tuple(intType.getWidth(), intType.isIndex());
294 return std::make_tuple(IndexType::kInternalStorageBitWidth,
true);
296 auto [bitwidth, isIndex] = getBitwidth(lb);
301 if (std::tie(bitwidth, isIndex) != getBitwidth(ub)) {
302 LDBG() <<
"mismatch between lb and ub bitwidth/type: " << ub <<
" vs "
307 return APInt(bitwidth, 0);
309 std::optional<std::pair<APInt, bool>> maybeStepCst =
313 auto &stepCst = maybeStepCst->first;
314 assert(
static_cast<int>(stepCst.getBitWidth()) == bitwidth &&
315 "step must have the same bitwidth as lb and ub");
316 if (stepCst.isZero())
318 if (stepCst.isNegative())
319 return APInt(bitwidth, 0);
324 <<
"Computing loop trip count for index type may break with overflow";
339 APSInt lbCst(maybeLbCst->first, !isSigned);
340 APSInt ubCst(maybeUbCst->first, !isSigned);
343 if (ubCst <= lbCst) {
344 LDBG() <<
"constantTripCount is 0 because ub <= lb (" << lbCst <<
"("
345 << lbCst.getBitWidth() <<
") <= " << ubCst <<
"("
346 << ubCst.getBitWidth() <<
"), "
347 << (isSigned ?
"isSigned" :
"isUnsigned") <<
")";
348 return APInt(bitwidth, 0);
350 diff = ubCst - lbCst;
357 std::optional<llvm::APSInt> maybeDiff =
363 LDBG() <<
"constantTripCount: " << (isSigned ?
"isSigned" :
"isUnsigned")
364 <<
", ub-lb: " << diff <<
"(" << diff.getBitWidth() <<
"b)";
365 if (diff.isNegative()) {
366 LDBG() <<
"constantTripCount is 0 because ub-lb diff is negative";
367 return APInt(bitwidth, 0);
371 <<
"constantTripCount can't be computed because step is not a constant";
374 auto &stepCst = maybeStepCst->first;
375 llvm::APInt tripCount = diff.sdiv(stepCst);
376 llvm::APInt r = diff.srem(stepCst);
378 tripCount = tripCount + 1;
379 LDBG() <<
"constantTripCount found: " << tripCount;
384 return llvm::none_of(sizesOrOffsets, [](int64_t value) {
385 return ShapedType::isStatic(value) && value < 0;
390 return llvm::none_of(strides, [](int64_t value) {
391 return ShapedType::isStatic(value) && value == 0;
396 bool onlyNonNegative,
bool onlyNonZero) {
397 bool valuesChanged =
false;
399 if (isa<Attribute>(ofr))
409 valuesChanged =
true;
412 return success(valuesChanged);
static std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
bool areConstantIntValues(ArrayRef< OpFoldResult > ofrs, ArrayRef< int64_t > values)
Return true if all of ofrs are constant integers equal to the corresponding value in values.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert an int64_t to an integer attribute of index type and return it as an OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
std::optional< std::pair< APInt, bool > > getConstantAPIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
static SmallVector< V > getValuesSortedByKeyImpl(ArrayRef< K > keys, ArrayRef< V > values, llvm::function_ref< bool(K, K)> compare)
Helper to sort values according to matching keys.
bool isEqualConstantIntOrValueArray(ArrayRef< OpFoldResult > ofrs1, ArrayRef< OpFoldResult > ofrs2)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void dispatchIndexOpFoldResult(OpFoldResult ofr, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch an OpFoldResult into staticVec if: a) it is an IntegerAttr. In other cases...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)
If all ofrs are constant integers or IntegerAttrs, return the integers.
SmallVector< Value > getValuesSortedByKey(ArrayRef< Attribute > keys, ArrayRef< Value > values, llvm::function_ref< bool(Attribute, Attribute)> compare)
Helper to sort values according to matching keys.
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.