#ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_STORAGE_H
#define MLIR_EXECUTIONENGINE_SPARSETENSOR_STORAGE_H

namespace mlir {
namespace sparse_tensor {
  SparseTensorStorageBase(uint64_t dimRank, const uint64_t *dimSizes,
                          uint64_t lvlRank, const uint64_t *lvlSizes,
                          const LevelType *lvlTypes, const uint64_t *dim2lvl,
                          const uint64_t *lvl2dim);
  /// Gets the tensor-dimension sizes array.
  const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
  /// Gets the storage-level sizes array.
  const std::vector<uint64_t> &getLvlSizes() const { return lvlSizes; }
  /// Gets the level-types array.
  const std::vector<LevelType> &getLvlTypes() const { return lvlTypes; }
  /// Gets positions-overhead storage for the given level.
#define DECL_GETPOSITIONS(PNAME, P)                                            \
  virtual void getPositions(std::vector<P> **, uint64_t);
  MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DECL_GETPOSITIONS)
#undef DECL_GETPOSITIONS
  /// Gets coordinates-overhead storage for the given level.
#define DECL_GETCOORDINATES(INAME, C)                                          \
  virtual void getCoordinates(std::vector<C> **, uint64_t);
  MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DECL_GETCOORDINATES)
#undef DECL_GETCOORDINATES
  /// Gets coordinates-overhead storage buffer for the given level.
#define DECL_GETCOORDINATESBUFFER(INAME, C)                                    \
  virtual void getCoordinatesBuffer(std::vector<C> **, uint64_t);
  MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DECL_GETCOORDINATESBUFFER)
#undef DECL_GETCOORDINATESBUFFER
  /// Gets primary storage.
#define DECL_GETVALUES(VNAME, V) virtual void getValues(std::vector<V> **);
  MLIR_SPARSETENSOR_FOREVERY_V(DECL_GETVALUES)
#undef DECL_GETVALUES
  /// Element-wise insertion in lexicographic coordinate order.
#define DECL_LEXINSERT(VNAME, V) virtual void lexInsert(const uint64_t *, V);
  MLIR_SPARSETENSOR_FOREVERY_V(DECL_LEXINSERT)
#undef DECL_LEXINSERT
  /// Expanded insertion.
#define DECL_EXPINSERT(VNAME, V)                                               \
  virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t,        \
                         uint64_t);
  MLIR_SPARSETENSOR_FOREVERY_V(DECL_EXPINSERT)
#undef DECL_EXPINSERT
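  // Illustrative note (not part of the original header): the FOREVERY macros
  // above stamp out one virtual overload per overhead/value type. Assuming
  // MLIR_SPARSETENSOR_FOREVERY_V covers (F64, double) among its cases, the
  // DECL_GETVALUES and DECL_LEXINSERT pair would expand roughly to
  //
  //   virtual void getValues(std::vector<double> **);
  //   virtual void lexInsert(const uint64_t *, double);
  //
  // and the derived SparseTensorStorage<P, C, V> below overrides only the
  // instantiations that match its own template parameters.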
  const std::vector<uint64_t> dimSizes;
  const std::vector<uint64_t> lvlSizes;
  const std::vector<LevelType> lvlTypes;
  const std::vector<uint64_t> dim2lvlVec;
  const std::vector<uint64_t> lvl2dimVec;
/// A memory-resident sparse tensor using a storage scheme based on
/// per-level sparse/dense annotations.
template <typename P, typename C, typename V>
class SparseTensorStorage final : public SparseTensorStorageBase {
  /// Private constructor to share code between the other constructors.
  SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
                      uint64_t lvlRank, const uint64_t *lvlSizes,
                      const LevelType *lvlTypes, const uint64_t *dim2lvl,
                      const uint64_t *lvl2dim)
      : SparseTensorStorageBase(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
                                dim2lvl, lvl2dim),
        positions(lvlRank), coordinates(lvlRank), lvlCursor(lvlRank) {}
  /// Constructs a sparse tensor with the given encoding, and allocates
  /// overhead storage according to some simple heuristics; when a level-COO
  /// is provided, the tensor contents are initialized from it.
  SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
                      uint64_t lvlRank, const uint64_t *lvlSizes,
                      const LevelType *lvlTypes, const uint64_t *dim2lvl,
                      const uint64_t *lvl2dim, SparseTensorCOO<V> *lvlCOO);
  /// Constructs a sparse tensor with the given encoding, and initializes
  /// the contents from the given level buffers.
  SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
                      uint64_t lvlRank, const uint64_t *lvlSizes,
                      const LevelType *lvlTypes, const uint64_t *dim2lvl,
                      const uint64_t *lvl2dim, const intptr_t *lvlBufs);
  /// Allocates a new empty sparse tensor.
  static SparseTensorStorage<P, C, V> *
  newEmpty(uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
           const uint64_t *lvlSizes, const LevelType *lvlTypes,
           const uint64_t *dim2lvl, const uint64_t *lvl2dim);
  /// Allocates a new sparse tensor and initializes it from the given COO.
  static SparseTensorStorage<P, C, V> *
  newFromCOO(uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
             const uint64_t *lvlSizes, const LevelType *lvlTypes,
             const uint64_t *dim2lvl, const uint64_t *lvl2dim,
             SparseTensorCOO<V> *lvlCOO);
  /// Allocates a new sparse tensor and initializes it from the given buffers.
  static SparseTensorStorage<P, C, V> *
  newFromBuffers(uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
                 const uint64_t *lvlSizes, const LevelType *lvlTypes,
                 const uint64_t *dim2lvl, const uint64_t *lvl2dim,
                 uint64_t srcRank, const intptr_t *buffers);
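  // Illustrative usage sketch (not part of the original header): building a
  // 4x8 CSR-like tensor from a level-COO and reading back its storage. The
  // template arguments, the identity dim-to-lvl maps, and the `denseLT` /
  // `compressedLT` level-type values are caller-side assumptions.
  //
  //   uint64_t dimSizes[] = {4, 8};
  //   uint64_t lvlSizes[] = {4, 8};
  //   LevelType lvlTypes[] = {denseLT, compressedLT}; // dense, then compressed
  //   uint64_t dim2lvl[] = {0, 1};                    // identity mapping
  //   uint64_t lvl2dim[] = {0, 1};
  //   SparseTensorCOO<double> *coo = /* filled and owned elsewhere */;
  //   auto *tensor = SparseTensorStorage<uint64_t, uint64_t, double>::newFromCOO(
  //       /*dimRank=*/2, dimSizes, /*lvlRank=*/2, lvlSizes, lvlTypes, dim2lvl,
  //       lvl2dim, coo);
  //   std::vector<uint64_t> *pos, *crd;
  //   std::vector<double> *val;
  //   tensor->getPositions(&pos, /*lvl=*/1);   // CSR row pointers
  //   tensor->getCoordinates(&crd, /*lvl=*/1); // CSR column indices
  //   tensor->getValues(&val);                 // nonzero values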
  /// Partially specialize these getter methods based on template types.
  void getPositions(std::vector<P> **out, uint64_t lvl) final {
    assert(out && "Received nullptr for out parameter");
    assert(lvl < getLvlRank());
    *out = &positions[lvl];
  }
  void getCoordinates(std::vector<C> **out, uint64_t lvl) final {
    assert(out && "Received nullptr for out parameter");
    assert(lvl < getLvlRank());
    *out = &coordinates[lvl];
  }
  void getCoordinatesBuffer(std::vector<C> **out, uint64_t lvl) final {
    assert(out && "Received nullptr for out parameter");
    assert(lvl < getLvlRank());
    uint64_t lvlRank = getLvlRank();
    uint64_t nnz = values.size();
    crdBuffer.clear();
    crdBuffer.reserve(nnz * (lvlRank - lvl));
    for (uint64_t i = 0; i < nnz; i++) {
      for (uint64_t l = lvl; l < lvlRank; l++) {
        assert(i < coordinates[l].size());
        crdBuffer.push_back(coordinates[l][i]);
      }
    }
    *out = &crdBuffer;
  }
  void getValues(std::vector<V> **out) final {
    assert(out && "Received nullptr for out parameter");
    *out = &values;
  }
  /// Partially specialize lexicographical insertions based on template types.
  void lexInsert(const uint64_t *lvlCoords, V val) final {
    assert(lvlCoords);
    if (allDense) {
      // All-dense storage: linearize the address into the values array.
      uint64_t lvlRank = getLvlRank();
      uint64_t valIdx = 0;
      for (uint64_t l = 0; l < lvlRank; l++)
        valIdx = valIdx * getLvlSize(l) + lvlCoords[l];
      values[valIdx] = val;
      return;
    }
    // First, wrap up the pending insertion path.
    uint64_t diffLvl = 0;
    uint64_t full = 0;
    if (!values.empty()) {
      diffLvl = lexDiff(lvlCoords);
      endPath(diffLvl + 1);
      full = lvlCursor[diffLvl] + 1;
    }
    // Then continue with the insertion path.
    insPath(lvlCoords, diffLvl, full, val);
  }
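  // Illustrative sketch (not part of the original header): with the CSR-like
  // encoding from the earlier sketch but a tensor obtained from newEmpty,
  // elements must be inserted in lexicographic coordinate order and then
  // finalized with endLexInsert(). The coordinates and values are arbitrary.
  //
  //   uint64_t coords[2];
  //   coords[0] = 0; coords[1] = 2; tensor->lexInsert(coords, 1.0);
  //   coords[0] = 0; coords[1] = 5; tensor->lexInsert(coords, 2.0);
  //   coords[0] = 3; coords[1] = 1; tensor->lexInsert(coords, 3.0);
  //   tensor->endLexInsert(); // finalizes positions of the compressed level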
  /// Partially specialize expanded insertions based on template types.
  void expInsert(uint64_t *lvlCoords, V *values, bool *filled, uint64_t *added,
                 uint64_t count, uint64_t expsz) final {
    assert((lvlCoords && values && filled && added) && "Received nullptr");
    if (count == 0)
      return;
    // Sort.
    std::sort(added, added + count);
    // Restore the insertion path for the first insert.
    const uint64_t lastLvl = getLvlRank() - 1;
    uint64_t c = added[0];
    assert(c <= expsz);
    assert(filled[c] && "added coordinate is not filled");
    lvlCoords[lastLvl] = c;
    lexInsert(lvlCoords, values[c]);
    values[c] = 0;
    filled[c] = false;
    // Subsequent insertions are quick.
    for (uint64_t i = 1; i < count; i++) {
      assert(c < added[i] && "non-lexicographic insertion");
      c = added[i];
      assert(c <= expsz);
      assert(filled[c] && "added coordinate is not filled");
      lvlCoords[lastLvl] = c;
      insPath(lvlCoords, lastLvl, added[i - 1] + 1, values[c]);
      values[c] = 0;
      filled[c] = false;
    }
  }
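  // Illustrative sketch (not part of the original header): expInsert is meant
  // to be driven by an "expanded access pattern" for a single outer position,
  // where dense scratch arrays of length `expsz` collect the innermost values
  // and `added` records which coordinates were touched. All names below are
  // hypothetical caller-side buffers, continuing the 4x8 CSR-like sketch.
  //
  //   double   scratchVal[8] = {0};        // expanded values (expsz == 8)
  //   bool     scratchFilled[8] = {false}; // which entries were written
  //   uint64_t addedCrds[8];               // touched column coordinates
  //   uint64_t numAdded = 0;
  //   scratchVal[5] += 2.5; scratchFilled[5] = true; addedCrds[numAdded++] = 5;
  //   uint64_t coords[2] = {3, 0};         // expInsert fills in the last level
  //   tensor->expInsert(coords, scratchVal, scratchFilled, addedCrds,
  //                     numAdded, /*expsz=*/8);
  //   // On return, the touched scratch entries are reset, so the buffers can
  //   // be reused for the next outer position.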
  /// Sorts the unordered tensor in place; assumes it is an unordered COO
  /// tensor.
  void sortInPlace() {
    uint64_t nnz = values.size();
#ifndef NDEBUG
    for (uint64_t l = 0; l < getLvlRank(); l++)
      assert(nnz == coordinates[l].size());
#endif

    // In-place permutation.
    auto applyPerm = [this](std::vector<uint64_t> &perm) {
      uint64_t length = perm.size();
      uint64_t lvlRank = getLvlRank();
      // Cache for the current level coordinates.
      std::vector<P> lvlCrds(lvlRank);
      for (uint64_t i = 0; i < length; i++) {
        uint64_t current = i;
        if (i != perm[current]) {
          for (uint64_t l = 0; l < lvlRank; l++)
            lvlCrds[l] = coordinates[l][i];
          V val = values[i];
          // Deals with a permutation cycle.
          while (i != perm[current]) {
            uint64_t next = perm[current];
            // Swaps the level coordinates and value.
            for (uint64_t l = 0; l < lvlRank; l++)
              coordinates[l][current] = coordinates[l][next];
            values[current] = values[next];
            perm[current] = current;
            current = next;
          }
          for (uint64_t l = 0; l < lvlRank; l++)
            coordinates[l][current] = lvlCrds[l];
          values[current] = val;
          perm[current] = current;
        }
      }
    };

    std::vector<uint64_t> sortedIdx(nnz, 0);
    for (uint64_t i = 0; i < nnz; i++)
      sortedIdx[i] = i;

    std::sort(sortedIdx.begin(), sortedIdx.end(),
              [this](uint64_t lhs, uint64_t rhs) {
                for (uint64_t l = 0; l < getLvlRank(); l++) {
                  if (coordinates[l][lhs] == coordinates[l][rhs])
                    continue;
                  return coordinates[l][lhs] < coordinates[l][rhs];
                }
                assert(lhs == rhs && "duplicate coordinates");
                return false;
              });

    applyPerm(sortedIdx);
  }
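  // Worked example (illustrative only): for a 2-level unordered COO tensor with
  //   coordinates[0] = {1, 0, 1}, coordinates[1] = {0, 2, 1}, values = {a, b, c}
  // the comparator yields sortedIdx = {1, 0, 2}, and applyPerm then cycles the
  // entries in place so that
  //   coordinates[0] = {0, 1, 1}, coordinates[1] = {2, 0, 1}, values = {b, a, c},
  // i.e. elements (0,2), (1,0), (1,1) in lexicographic order, without
  // allocating a second copy of the coordinate and value arrays.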
  /// Appends coordinate `crd` to level `lvl`.
  void appendCrd(uint64_t lvl, uint64_t full, uint64_t crd) {
    if (!isDenseLvl(lvl)) {
      assert(isCompressedLvl(lvl) || isLooseCompressedLvl(lvl) ||
             isSingletonLvl(lvl) || isNOutOfMLvl(lvl));
      coordinates[lvl].push_back(detail::checkOverflowCast<C>(crd));
    } else { // Dense level.
      assert(crd >= full && "Coordinate was already filled");
      if (crd == full)
        return; // Short-circuit, since it is a nop.
      if (lvl + 1 == getLvlRank())
        values.insert(values.end(), crd - full, 0);
      else
        finalizeSegment(lvl + 1, 0, crd - full);
    }
  }
  /// Computes the assembled-size associated with the `l`-th level,
  /// given the assembled-size associated with the previous level.
  uint64_t assembledSize(uint64_t parentSz, uint64_t l) const {
    if (isCompressedLvl(l))
      return positions[l][parentSz];
    if (isLooseCompressedLvl(l))
      return positions[l][2 * parentSz - 1];
    if (isSingletonLvl(l) || isNOutOfMLvl(l))
      return parentSz;
    assert(isDenseLvl(l));
    return parentSz * getLvlSize(l);
  }
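  // Example (illustrative): for a 4x8 CSR-like tensor (dense, then compressed),
  // assembledSize(1, 0) == 4, since a dense level multiplies by its size, and
  // assembledSize(4, 1) == positions[1][4], the total number of stored entries,
  // which is also the length of coordinates[1] and of values.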
  /// Initializes the storage scheme from a sorted COO element list.
  void fromCOO(const std::vector<Element<V>> &lvlElements, uint64_t lo,
               uint64_t hi, uint64_t l) {
    const uint64_t lvlRank = getLvlRank();
    assert(l <= lvlRank && hi <= lvlElements.size());
    // Once levels are exhausted, insert the numerical values.
    if (l == lvlRank) {
      values.push_back(lvlElements[lo].value);
      return;
    }
    uint64_t full = 0;
    while (lo < hi) {
      const uint64_t c = lvlElements[lo].coords[l];
      uint64_t seg = lo + 1;
      if (isUniqueLvl(l))
        while (seg < hi && lvlElements[seg].coords[l] == c)
          seg++;
      appendCrd(l, full, c);
      full = c + 1;
      fromCOO(lvlElements, lo, seg, l + 1);
      lo = seg;
    }
    finalizeSegment(l, full);
  }
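  // Worked example (illustrative): for a 3x4 matrix with nonzeros
  //   (0,1)=a, (0,3)=b, (2,2)=c
  // and level types (dense, compressed), fromCOO over the sorted elements yields
  //   positions[1]   = {0, 2, 2, 3}   // row pointers
  //   coordinates[1] = {1, 3, 2}      // column indices
  //   values         = {a, b, c}
  // which is exactly the classical CSR representation of that matrix.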
  /// Finalizes the sparse position structure at this level.
  void finalizeSegment(uint64_t l, uint64_t full = 0, uint64_t count = 1) {
    if (count == 0)
      return; // Short-circuit, since it is a nop.
    if (isCompressedLvl(l)) {
      uint64_t pos = coordinates[l].size();
      positions[l].insert(positions[l].end(), count,
                          detail::checkOverflowCast<P>(pos));
    } else if (isLooseCompressedLvl(l)) {
      uint64_t pos = coordinates[l].size();
      positions[l].insert(positions[l].end(), 2 * count,
                          detail::checkOverflowCast<P>(pos));
    } else if (isSingletonLvl(l) || isNOutOfMLvl(l)) {
      return; // Nothing to finalize.
    } else { // Dense level.
      assert(isDenseLvl(l));
      const uint64_t sz = getLvlSizes()[l];
      assert(sz >= full && "Segment is overfull");
      count = detail::checkedMul(count, sz - full);
      // Fill in the remaining zero values, or recurse to a deeper level.
      if (l + 1 == getLvlRank())
        values.insert(values.end(), count, 0);
      else
        finalizeSegment(l + 1, 0, count);
    }
  }
  /// Wraps up a single insertion path, inner to outer.
  void endPath(uint64_t diffLvl) {
    const uint64_t lvlRank = getLvlRank();
    const uint64_t lastLvl = lvlRank - 1;
    assert(diffLvl <= lvlRank);
    const uint64_t stop = lvlRank - diffLvl;
    for (uint64_t i = 0; i < stop; i++) {
      const uint64_t l = lastLvl - i;
      finalizeSegment(l, lvlCursor[l] + 1);
    }
  }
  /// Continues a single insertion path, outer to inner.
  void insPath(const uint64_t *lvlCoords, uint64_t diffLvl, uint64_t full,
               V val) {
    const uint64_t lvlRank = getLvlRank();
    assert(diffLvl <= lvlRank);
    for (uint64_t l = diffLvl; l < lvlRank; l++) {
      const uint64_t c = lvlCoords[l];
      appendCrd(l, full, c);
      full = 0;
      lvlCursor[l] = c;
    }
    values.push_back(val);
  }
  /// Finds the lexicographically first level where the level-coordinates
  /// in the argument differ from those in the current cursor.
  uint64_t lexDiff(const uint64_t *lvlCoords) const {
    const uint64_t lvlRank = getLvlRank();
    for (uint64_t l = 0; l < lvlRank; l++) {
      const auto crd = lvlCoords[l];
      const auto cur = lvlCursor[l];
      if (crd > cur || (crd == cur && !isUniqueLvl(l)) ||
          (crd < cur && !isOrderedLvl(l))) {
        return l;
      }
      if (crd < cur) {
        assert(false && "non-lexicographic insertion");
        return -1u;
      }
    }
    assert(false && "duplicate insertion");
    return -1u;
  }
  // Sparse tensor storage components.
  std::vector<std::vector<P>> positions;
  std::vector<std::vector<C>> coordinates;
  std::vector<V> values;

  // Auxiliary data structures.
  std::vector<uint64_t> lvlCursor;
  std::vector<C> crdBuffer; // only used to provide a temporary AoS view
};
template <typename P, typename C, typename V>
SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newEmpty(
    uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
    const uint64_t *lvlSizes, const LevelType *lvlTypes,
    const uint64_t *dim2lvl, const uint64_t *lvl2dim) {
  SparseTensorCOO<V> *noLvlCOO = nullptr;
  return new SparseTensorStorage<P, C, V>(dimRank, dimSizes, lvlRank, lvlSizes,
                                          lvlTypes, dim2lvl, lvl2dim, noLvlCOO);
}
template <typename P, typename C, typename V>
SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newFromCOO(
    uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
    const uint64_t *lvlSizes, const LevelType *lvlTypes,
    const uint64_t *dim2lvl, const uint64_t *lvl2dim,
    SparseTensorCOO<V> *lvlCOO) {
  assert(lvlCOO);
  return new SparseTensorStorage<P, C, V>(dimRank, dimSizes, lvlRank, lvlSizes,
                                          lvlTypes, dim2lvl, lvl2dim, lvlCOO);
}
template <typename P, typename C, typename V>
SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newFromBuffers(
    uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
    const uint64_t *lvlSizes, const LevelType *lvlTypes,
    const uint64_t *dim2lvl, const uint64_t *lvl2dim, uint64_t srcRank,
    const intptr_t *buffers) {
  return new SparseTensorStorage<P, C, V>(dimRank, dimSizes, lvlRank, lvlSizes,
                                          lvlTypes, dim2lvl, lvl2dim, buffers);
}
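// Illustrative sketch (not part of the original header): for a CSR-like
// (dense, then compressed) encoding, newFromBuffers expects one positions
// buffer and one coordinates buffer per sparse level, followed by the values
// buffer, all passed as opaque pointers. The data below encodes a hypothetical
// 3x4 matrix with three nonzeros; the remaining arguments are the same
// caller-side assumptions as in the earlier sketches.
//
//   uint64_t pos[] = {0, 2, 2, 3};  // row pointers
//   uint64_t crd[] = {1, 3, 2};     // column indices
//   double   val[] = {1.0, 2.0, 3.0};
//   intptr_t bufs[] = {reinterpret_cast<intptr_t>(pos),
//                      reinterpret_cast<intptr_t>(crd),
//                      reinterpret_cast<intptr_t>(val)};
//   auto *csr = SparseTensorStorage<uint64_t, uint64_t, double>::newFromBuffers(
//       /*dimRank=*/2, dimSizes, /*lvlRank=*/2, lvlSizes, lvlTypes, dim2lvl,
//       lvl2dim, /*srcRank=*/2, bufs);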
template <typename P, typename C, typename V>
SparseTensorStorage<P, C, V>::SparseTensorStorage(
    uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
    const uint64_t *lvlSizes, const LevelType *lvlTypes,
    const uint64_t *dim2lvl, const uint64_t *lvl2dim,
    SparseTensorCOO<V> *lvlCOO)
    : SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
                          dim2lvl, lvl2dim) {
  // Provide hints on the capacity of positions and coordinates.
  uint64_t sz = 1;
  for (uint64_t l = 0; l < lvlRank; l++) {
    if (isCompressedLvl(l)) {
      positions[l].reserve(sz + 1);
      positions[l].push_back(0);
      coordinates[l].reserve(sz);
      sz = 1;
    } else if (isLooseCompressedLvl(l)) {
      positions[l].reserve(2 * sz + 1);
      positions[l].push_back(0);
      coordinates[l].reserve(sz);
      sz = 1;
    } else if (isSingletonLvl(l)) {
      coordinates[l].reserve(sz);
      sz = 1;
    } else if (isNOutOfMLvl(l)) {
      assert(l == lvlRank - 1 && "unexpected n:m usage");
      coordinates[l].reserve(sz);
    } else {
      assert(isDenseLvl(l));
      sz = detail::checkedMul(sz, lvlSizes[l]);
    }
  }
  if (lvlCOO) {
    // Ensure the preconditions of `fromCOO`, then insert all elements.
    assert(lvlCOO->getRank() == lvlRank);
    lvlCOO->sort();
    const auto &elements = lvlCOO->getElements();
    const uint64_t nse = elements.size();
    assert(values.size() == 0);
    fromCOO(elements, 0, nse, 0);
  } else if (allDense) {
    values.resize(sz, 0);
  }
}
template <typename P, typename C, typename V>
SparseTensorStorage<P, C, V>::SparseTensorStorage(
    uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
    const uint64_t *lvlSizes, const LevelType *lvlTypes,
    const uint64_t *dim2lvl, const uint64_t *lvl2dim, const intptr_t *lvlBufs)
    : SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
                          dim2lvl, lvl2dim) {
  // None of the buffers can be reused, since ownership of the memory passed
  // from clients is not necessarily transferred; all data is copied over.
  uint64_t trailCOOLen = 0, parentSz = 1, bufIdx = 0;
  for (uint64_t l = 0; l < lvlRank; l++) {
    if (!isUniqueLvl(l) && (isCompressedLvl(l) || isLooseCompressedLvl(l))) {
      // A non-unique (loose)compressed level marks the start of trailing COO
      // levels, whose coordinates are passed in as a single AoS buffer and
      // therefore get dedicated handling after this loop.
      trailCOOLen = lvlRank - l;
      break;
    }
    if (isCompressedLvl(l) || isLooseCompressedLvl(l)) {
      P *posPtr = reinterpret_cast<P *>(lvlBufs[bufIdx++]);
      C *crdPtr = reinterpret_cast<C *>(lvlBufs[bufIdx++]);
      if (isLooseCompressedLvl(l)) {
        positions[l].assign(posPtr, posPtr + 2 * parentSz);
        coordinates[l].assign(crdPtr, crdPtr + positions[l][2 * parentSz - 1]);
      } else {
        positions[l].assign(posPtr, posPtr + parentSz + 1);
        coordinates[l].assign(crdPtr, crdPtr + positions[l][parentSz]);
      }
    } else if (isSingletonLvl(l)) {
      assert(0 && "general singleton not supported yet");
    } else if (isNOutOfMLvl(l)) {
      assert(0 && "n out of m not supported yet");
    } else {
      assert(isDenseLvl(l));
    }
    parentSz = assembledSize(parentSz, l);
  }
  // Handle the trailing COO levels, if any, converting the AoS coordinate
  // buffer into the SoA coordinate storage used internally.
  if (trailCOOLen != 0) {
    uint64_t cooStartLvl = lvlRank - trailCOOLen;
    assert(!isUniqueLvl(cooStartLvl) &&
           (isCompressedLvl(cooStartLvl) || isLooseCompressedLvl(cooStartLvl)));
    P *posPtr = reinterpret_cast<P *>(lvlBufs[bufIdx++]);
    C *aosCrdPtr = reinterpret_cast<C *>(lvlBufs[bufIdx++]);
    uint64_t crdLen;
    if (isLooseCompressedLvl(cooStartLvl)) {
      positions[cooStartLvl].assign(posPtr, posPtr + 2 * parentSz);
      crdLen = positions[cooStartLvl][2 * parentSz - 1];
    } else {
      positions[cooStartLvl].assign(posPtr, posPtr + parentSz + 1);
      crdLen = positions[cooStartLvl][parentSz];
    }
    for (uint64_t l = cooStartLvl; l < lvlRank; l++) {
      coordinates[l].resize(crdLen);
      for (uint64_t n = 0; n < crdLen; n++) {
        coordinates[l][n] = *(aosCrdPtr + (l - cooStartLvl) + n * trailCOOLen);
      }
    }
    parentSz = assembledSize(parentSz, cooStartLvl);
  }

  // Copy the values buffer.
  V *valPtr = reinterpret_cast<V *>(lvlBufs[bufIdx]);
  values.assign(valPtr, valPtr + parentSz);
}

} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_EXECUTIONENGINE_SPARSETENSOR_STORAGE_H