28 #ifndef MLIR_DIALECT_SPARSETENSOR_IR_ENUMS_H
29 #define MLIR_DIALECT_SPARSETENSOR_IR_ENUMS_H
41 namespace sparse_tensor {
63 #define MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DO) \
71 #define MLIR_SPARSETENSOR_FOREVERY_O(DO) \
72 MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DO) \
96 #define MLIR_SPARSETENSOR_FOREVERY_V(DO) \
109 #define MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, ...) \
110 DO(F64, double, __VA_ARGS__) \
111 DO(F32, float, __VA_ARGS__) \
112 DO(F16, f16, __VA_ARGS__) \
113 DO(BF16, bf16, __VA_ARGS__) \
114 DO(I64, int64_t, __VA_ARGS__) \
115 DO(I32, int32_t, __VA_ARGS__) \
116 DO(I16, int16_t, __VA_ARGS__) \
117 DO(I8, int8_t, __VA_ARGS__) \
118 DO(C64, complex64, __VA_ARGS__) \
119 DO(C32, complex32, __VA_ARGS__)
122 #define MLIR_SPARSETENSOR_FOREVERY_V_O(DO) \
123 MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 64, uint64_t) \
124 MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 32, uint32_t) \
125 MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 16, uint16_t) \
126 MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 8, uint8_t) \
127 MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 0, index_type)
165 auto enc =
static_cast<std::underlying_type_t<LevelFormat>
>(fmt);
166 return (enc & (enc - 1)) == 0;
179 return (... || (targets == fmt));
196 return "loose_compressed";
242 auto fmt =
static_cast<LevelFormat>(lvlBits & 0xffff0000);
243 const uint64_t propertyBits = lvlBits & 0xffff;
248 ? (propertyBits == 0)
256 static std::optional<LevelType>
258 const std::vector<LevelPropNonDefault> &properties,
259 uint64_t n = 0, uint64_t m = 0) {
260 assert((n & 0xff) == n && (m & 0xff) == m);
261 uint64_t newN = n << 32;
262 uint64_t newM = m << 40;
263 uint64_t ltBits =
static_cast<uint64_t
>(lf) | newN | newM;
264 for (
auto p : properties)
265 ltBits |=
static_cast<uint64_t
>(p);
271 bool unique, uint64_t n = 0,
273 std::vector<LevelPropNonDefault> properties;
282 constexpr
explicit LevelType(uint64_t bits) : lvlBits(bits) {
292 explicit operator uint64_t()
const {
return lvlBits; }
295 return static_cast<uint64_t
>(lhs) == lvlBits;
302 constexpr uint64_t mask =
308 constexpr uint64_t
getN()
const {
309 assert(isa<LevelFormat::NOutOfM>());
310 return (lvlBits >> 32) & 0xff;
314 constexpr uint64_t
getM()
const {
315 assert(isa<LevelFormat::NOutOfM>());
316 return (lvlBits >> 40) & 0xff;
321 return static_cast<LevelFormat>(lvlBits & 0xffff0000);
326 constexpr
bool isa()
const {
327 return (... || (
getLvlFmt() == fmt)) ||
false;
331 template <LevelPropNonDefault p>
332 constexpr
bool isa()
const {
333 return lvlBits &
static_cast<uint64_t
>(p);
344 return isa<LevelFormat::Dense, LevelFormat::Batch>();
349 assert(!isa<LevelFormat::Undef>());
350 return isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>();
355 assert(!isa<LevelFormat::Undef>());
366 std::string propStr =
"";
367 if (isa<LevelFormat::NOutOfM>()) {
369 "[" + std::to_string(
getN()) +
", " + std::to_string(
getM()) +
"]";
371 if (isa<LevelPropNonDefault::Nonunique>())
374 if (isa<LevelPropNonDefault::Nonordered>()) {
375 if (!propStr.empty())
379 if (isa<LevelPropNonDefault::SoA>()) {
380 if (!propStr.empty())
384 if (!propStr.empty())
385 lvlStr += (
"(" + propStr +
")");
/// Places the N of an N:M (NOutOfM) level type into its slot of the 64-bit
/// level-type encoding, i.e. bits 32-39 (the field getN() reads back with
/// `(bits >> 32) & 0xff`).
///
/// @param n  the N value; must fit in 8 bits. A wider value would spill into
///           the M field at bit 40 and corrupt the encoding, so this is
///           asserted, mirroring the check performed by buildLvlType.
/// @return   `n` shifted into position; all other bits are zero.
constexpr uint64_t nToBits(uint64_t n) {
  assert((n & 0xffu) == n && "N of an N:M level type must fit in 8 bits");
  return n << 32;
}
/// Places the M of an N:M (NOutOfM) level type into its slot of the 64-bit
/// level-type encoding, i.e. bits 40-47 (the field getM() reads back with
/// `(bits >> 40) & 0xff`).
///
/// @param m  the M value; must fit in 8 bits. Wider values would land above
///           bit 47, where getM() silently drops them, so this is asserted,
///           mirroring the check performed by buildLvlType.
/// @return   `m` shifted into position; all other bits are zero.
constexpr uint64_t mToBits(uint64_t m) {
  assert((m & 0xffu) == m && "M of an N:M level type must fit in 8 bits");
  return m << 40;
}
401 inline std::optional<LevelType>
403 const std::vector<LevelPropNonDefault> &properties,
404 uint64_t n = 0, uint64_t m = 0) {
408 bool unique, uint64_t n = 0,
471 constexpr uint64_t
encodeDim(uint64_t i, uint64_t cf, uint64_t cm) {
473 assert(cf <= 0xfffffu && cm == 0 && i <= 0xfffffu);
474 return (
static_cast<uint64_t
>(0x01u) << 60) | (cf << 20) | i;
477 assert(cm <= 0xfffffu && i <= 0xfffffu);
478 return (
static_cast<uint64_t
>(0x02u) << 60) | (cm << 20) | i;
480 assert(i <= 0x0fffffffffffffffu);
483 constexpr uint64_t
encodeLvl(uint64_t i, uint64_t c, uint64_t ii) {
485 assert(c <= 0xfffffu && ii <= 0xfffffu && i <= 0xfffffu);
486 return (
static_cast<uint64_t
>(0x03u) << 60) | (c << 20) | (ii << 40) | i;
488 assert(i <= 0x0fffffffffffffffu);
/// Extracts the low 20-bit field (bits 0-19) of an affine-encoded value:
/// the slot where encodeDim/encodeLvl store their index operand `i`.
constexpr uint64_t decodeIndex(uint64_t v) {
  constexpr uint64_t kLow20Mask = (uint64_t{1} << 20) - 1;
  return v & kLow20Mask;
}
/// Extracts the 20-bit field at bits 20-39 of an affine-encoded value:
/// the slot where encodeDim stores `cf`/`cm` and encodeLvl stores `c`.
constexpr uint64_t decodeConst(uint64_t v) {
  return (v & (uint64_t{0xfffff} << 20)) >> 20;
}
/// Extracts the 20-bit field at bits 20-39 of an affine-encoded value (same
/// bit range as decodeConst). Presumably the constant multiplier of a
/// multiplication encoding, judging by the name; confirm against callers.
constexpr uint64_t decodeMulc(uint64_t v) {
  const uint64_t fieldBits = v & (uint64_t{0xfffff} << 20);
  return fieldBits >> 20;
}
/// Extracts the 20-bit field at bits 40-59 of an affine-encoded value:
/// the slot where encodeLvl stores its `ii` operand.
constexpr uint64_t decodeMuli(uint64_t v) {
  return (v & (uint64_t{0xfffff} << 40)) >> 40;
}
std::complex< double > complex64
bool isUniqueLT(LevelType lt)
LevelFormat
This enum defines all supported storage formats, without the level properties.
bool isWithCrdLT(LevelType lt)
bool isWithPosLT(LevelType lt)
bool isOrderedLT(LevelType lt)
std::string toMLIRString(LevelType lt)
constexpr uint64_t decodeMuli(uint64_t v)
OverheadType
Encoding of overhead types (both position overhead and coordinate overhead), for "overloading" @newSparseTensor.
bool isSingletonLT(LevelType lt)
Action
The actions performed by @newSparseTensor.
bool isCompressedLT(LevelType lt)
constexpr bool isAnyOfFmt(LevelFormat fmt)
constexpr bool isEncodedMul(uint64_t v)
bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m)
constexpr uint64_t mToBits(uint64_t m)
constexpr bool isIntegralPrimaryType(PrimaryType valTy)
constexpr bool isEncodedMod(uint64_t v)
constexpr uint64_t decodeConst(uint64_t v)
uint64_t getN(LevelType lt)
std::optional< LevelFormat > getLevelFormat(LevelType lt)
bool isLooseCompressedLT(LevelType lt)
PrimaryType
Encoding of the elemental type, for "overloading" @newSparseTensor.
std::complex< float > complex32
constexpr uint64_t decodeIndex(uint64_t v)
constexpr bool isRealPrimaryType(PrimaryType valTy)
constexpr const char * toFormatString(LevelFormat lvlFmt)
Returns string representation of the given level format.
constexpr uint64_t nToBits(uint64_t n)
uint64_t index_type
This type is used in the public API at all places where MLIR expects values with the built-in type "index".
bool isUndefLT(LevelType lt)
constexpr bool isComplexPrimaryType(PrimaryType valTy)
bool isDenseLT(LevelType lt)
constexpr uint64_t encodeLvl(uint64_t i, uint64_t c, uint64_t ii)
bool isValidLT(LevelType lt)
constexpr uint64_t decodeMulc(uint64_t v)
uint64_t getM(LevelType lt)
constexpr uint64_t encodeDim(uint64_t i, uint64_t cf, uint64_t cm)
Bit manipulations for affine encoding.
constexpr const char * toPropString(LevelPropNonDefault lvlProp)
Returns string representation of the given level properties.
constexpr bool isFloatingPrimaryType(PrimaryType valTy)
constexpr bool encPowOfTwo(LevelFormat fmt)
bool isBatchLT(LevelType lt)
LevelPropNonDefault
This enum defines all the nondefault properties for storage formats.
std::optional< LevelType > buildLevelType(LevelFormat lf, const std::vector< LevelPropNonDefault > &properties, uint64_t n=0, uint64_t m=0)
constexpr bool isEncodedFloor(uint64_t v)
bool isNOutOfMLT(LevelType lt)
Include the generated interface declarations.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
constexpr bool hasSparseSemantic() const
Check if the LevelType is considered to be sparse.
constexpr unsigned getNumBuffer() const
static std::optional< LevelType > buildLvlType(LevelFormat lf, const std::vector< LevelPropNonDefault > &properties, uint64_t n=0, uint64_t m=0)
Convert a LevelFormat to its corresponding LevelType with the given properties.
constexpr bool hasDenseSemantic() const
Check if the LevelType is considered to be dense-like.
constexpr uint64_t getM() const
Get M of NOutOfM level type.
constexpr LevelFormat getLvlFmt() const
Get the LevelFormat of the LevelType.
bool operator==(const LevelType lhs) const
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.
constexpr bool isWithPosLT() const
Check if the LevelType needs positions array.
constexpr LevelType(uint64_t bits)
Explicit conversion from uint64_t.
static std::optional< LevelType > buildLvlType(LevelFormat lf, bool ordered, bool unique, uint64_t n=0, uint64_t m=0)
constexpr uint64_t getN() const
Get N of NOutOfM level type.
bool operator!=(const LevelType lhs) const
LevelType stripStorageIrrelevantProperties() const
static constexpr bool isValidLvlBits(uint64_t lvlBits)
Check that the LevelType contains a valid (possibly undefined) value.
std::string toMLIRString() const
constexpr bool isWithCrdLT() const
Check if the LevelType needs coordinates array.
LevelType(LevelFormat f)
Constructs a LevelType with the given format using all default properties.