13#include "llvm/ADT/APSInt.h"
14#include "llvm/ADT/STLExtras.h"
15#include "llvm/ADT/SmallVectorExtras.h"
16#include "llvm/Support/DebugLog.h"
17#include "llvm/Support/MathExtras.h"
24 if (
auto attr = dyn_cast<Attribute>(v)) {
25 if (
auto floatAttr = dyn_cast<FloatAttr>(attr))
26 return floatAttr.getValue().isZero();
38std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
39 SmallVector<OpFoldResult>>
42 offsets.reserve(ranges.size());
43 sizes.reserve(ranges.size());
44 strides.reserve(ranges.size());
45 for (
const auto &[offset, size, stride] : ranges) {
46 offsets.push_back(offset);
47 sizes.push_back(size);
48 strides.push_back(stride);
50 return std::make_tuple(offsets, sizes, strides);
62 auto v = llvm::dyn_cast_if_present<Value>(ofr);
64 APInt apInt = cast<IntegerAttr>(cast<Attribute>(ofr)).getValue();
65 staticVec.push_back(apInt.getSExtValue());
68 dynamicVec.push_back(v);
69 staticVec.push_back(ShapedType::kDynamic);
72std::pair<int64_t, OpFoldResult>
78 (tileSizeForShape != ShapedType::kDynamic)
79 ?
b.getIndexAttr(tileSizeForShape)
82 return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
83 tileSizeOfrSimplified);
107 return llvm::map_to_vector(values,
114 res.reserve(arrayAttr.size());
121 return IntegerAttr::get(IndexType::get(ctx), val);
126 return llvm::map_to_vector(
134 if (
auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
137 return std::make_pair(intVal, val.getType().isIndex());
141 Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
142 if (
auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
143 return std::make_pair(intAttr.getValue(), intAttr.getType().isIndex());
152 return apInt->first.getSExtValue();
155std::optional<SmallVector<int64_t>>
158 res.reserve(ofrs.size());
163 res.push_back(cv.value());
179 if (ofrs.size() != values.size())
182 return constOfrs && llvm::equal(constOfrs.value(), values);
191 if (cst1 && cst2 && *cst1 == *cst2)
193 auto v1 = llvm::dyn_cast_if_present<Value>(ofr1),
194 v2 = llvm::dyn_cast_if_present<Value>(ofr2);
195 return v1 && v1 == v2;
200 if (ofrs1.size() != ofrs2.size())
202 for (
auto [ofr1, ofr2] : llvm::zip_equal(ofrs1, ofrs2))
214 assert(dynamicValues.size() ==
static_cast<size_t>(llvm::count_if(
215 staticValues, ShapedType::isDynamic)) &&
216 "expected the rank of dynamic values to match the number of "
217 "values known to be dynamic");
219 res.reserve(staticValues.size());
220 unsigned numDynamic = 0;
221 unsigned count =
static_cast<unsigned>(staticValues.size());
222 for (
unsigned idx = 0; idx < count; ++idx) {
223 int64_t value = staticValues[idx];
224 res.push_back(ShapedType::isDynamic(value)
227 IntegerType::get(context, 64), staticValues[idx])});
238std::pair<SmallVector<int64_t>, SmallVector<Value>>
242 for (
const auto &it : mixedValues) {
243 if (
auto attr = dyn_cast<Attribute>(it)) {
244 staticValues.push_back(cast<IntegerAttr>(attr).getInt());
246 staticValues.push_back(ShapedType::kDynamic);
247 dynamicValues.push_back(cast<Value>(it));
250 return {staticValues, dynamicValues};
254template <
typename K,
typename V>
260 assert(keys.size() == values.size() &&
"unexpected mismatching sizes");
261 auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
265 res.reserve(values.size());
267 res.push_back(values[
indices[i]]);
277SmallVector<OpFoldResult>
297 auto getBitwidth = [&](
OpFoldResult ofr) -> std::tuple<int, bool> {
299 dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(ofr))) {
300 if (
auto intType = dyn_cast<IntegerType>(intAttr.getType()))
301 return std::make_tuple(intType.getWidth(), intType.isIndex());
303 auto val = cast<Value>(ofr);
304 if (
auto intType = dyn_cast<IntegerType>(val.getType()))
305 return std::make_tuple(intType.getWidth(), intType.isIndex());
307 return std::make_tuple(IndexType::kInternalStorageBitWidth,
true);
309 auto [bitwidth, isIndex] = getBitwidth(lb);
314 if (std::tie(bitwidth, isIndex) != getBitwidth(
ub)) {
315 LDBG() <<
"mismatch between lb and ub bitwidth/type: " <<
ub <<
" vs "
323 return APInt(bitwidth, 0);
326 std::optional<std::pair<APInt, bool>> maybeStepCst =
330 auto &stepCst = maybeStepCst->first;
331 assert(
static_cast<int>(stepCst.getBitWidth()) == bitwidth &&
332 "step must have the same bitwidth as lb and ub");
333 if (stepCst.isZero()) {
343 <<
"Computing loop trip count for index type may break with overflow";
358 APSInt lbCst(maybeLbCst->first, !isSigned);
359 APSInt ubCst(maybeUbCst->first, !isSigned);
360 if (ubCst <= lbCst) {
361 LDBG() <<
"constantTripCount is 0 because ub <= lb (" << lbCst <<
"("
362 << lbCst.getBitWidth() <<
") <= " << ubCst <<
"("
363 << ubCst.getBitWidth() <<
"), "
364 << (isSigned ?
"isSigned" :
"isUnsigned") <<
")";
365 return APInt(bitwidth, 0);
369 diff = ubCst - lbCst;
373 diff.setIsUnsigned(
true);
380 std::optional<llvm::APSInt> maybeDiff =
381 computeUbMinusLb(cast<Value>(lb), cast<Value>(
ub), isSigned);
386 LDBG() <<
"constantTripCount: " << (isSigned ?
"isSigned" :
"isUnsigned")
387 <<
", ub-lb: " << diff <<
"(" << diff.getBitWidth() <<
"b)";
388 if (diff.isNegative()) {
389 LDBG() <<
"constantTripCount is 0 because ub-lb diff is negative";
390 return APInt(bitwidth, 0);
394 <<
"constantTripCount can't be computed because step is not a constant";
397 auto &stepCst = maybeStepCst->first;
400 if (isSigned && stepCst.isSignBitSet()) {
401 LDBG() <<
"constantTripCount is infinite because step is negative";
408 llvm::APInt tripCount = diff.udiv(stepCst);
409 llvm::APInt remainder = diff.urem(stepCst);
410 if (!remainder.isZero())
411 tripCount = tripCount + 1;
413 LDBG() <<
"constantTripCount found: " << tripCount;
418 return llvm::none_of(sizesOrOffsets, [](
int64_t value) {
419 return ShapedType::isStatic(value) && value < 0;
424 return llvm::none_of(strides, [](
int64_t value) {
425 return ShapedType::isStatic(value) && value == 0;
430 bool onlyNonNegative,
bool onlyNonZero) {
431 bool valuesChanged =
false;
433 if (isa<Attribute>(ofr))
443 valuesChanged =
true;
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
bool areConstantIntValues(ArrayRef< OpFoldResult > ofrs, ArrayRef< int64_t > values)
Return true if all of ofrs are constant integers equal to the corresponding value in values.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
static SmallVector< V > getValuesSortedByKeyImpl(ArrayRef< K > keys, ArrayRef< V > values, llvm::function_ref< bool(K, K)> compare)
Helper to sort values according to matching keys.
bool isZeroIntegerOrFloat(OpFoldResult v)
Return "true" if v is an integer/float value/attribute with constant value zero.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
std::optional< std::pair< APInt, bool > > getConstantAPIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
bool isEqualConstantIntOrValueArray(ArrayRef< OpFoldResult > ofrs1, ArrayRef< OpFoldResult > ofrs2)
void dispatchIndexOpFoldResult(OpFoldResult ofr, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch an OpFoldResult into staticVec if: a) it is an IntegerAttr. In other cases...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)
If all ofrs are constant integers or IntegerAttrs, return the integers.
SmallVector< Value > getValuesSortedByKey(ArrayRef< Attribute > keys, ArrayRef< Value > values, llvm::function_ref< bool(Attribute, Attribute)> compare)
Helper to sort values according to matching keys.
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step,...
bool isZeroFloat(OpFoldResult v)
Return "true" if v is a float value/attribute with constant value 0.0.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.