#define CMPI(p, lhs, rhs)                                                      \
  (arith::CmpIOp::create(b, l, arith::CmpIPredicate::p, (lhs), (rhs))          \
       .getResult())
#define C_FALSE (constantI1(b, l, false))
#define C_TRUE (constantI1(b, l, true))
#define C_IDX(v) (constantIndex(b, l, (v)))
#define YIELD(vs) (scf::YieldOp::create(b, l, (vs)))
#define ADDI(lhs, rhs) (arith::AddIOp::create(b, l, (lhs), (rhs)).getResult())
#define ORI(lhs, rhs) (arith::OrIOp::create(b, l, (lhs), (rhs)).getResult())
#define ANDI(lhs, rhs) (arith::AndIOp::create(b, l, (lhs), (rhs)).getResult())
#define SUBI(lhs, rhs) (arith::SubIOp::create(b, l, (lhs), (rhs)).getResult())
#define MULI(lhs, rhs) (arith::MulIOp::create(b, l, (lhs), (rhs)).getResult())
#define MINUI(lhs, rhs) (arith::MinUIOp::create(b, l, (lhs), (rhs)).getResult())
#define REMUI(lhs, rhs) (arith::RemUIOp::create(b, l, (lhs), (rhs)).getResult())
#define DIVUI(lhs, rhs) (arith::DivUIOp::create(b, l, (lhs), (rhs)).getResult())
#define SELECT(c, lhs, rhs)                                                    \
  (arith::SelectOp::create(b, l, (c), (lhs), (rhs)).getResult())

using ValuePair = std::pair<Value, Value>;
using ValueTuple = std::tuple<Value, Value, Value>;
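// All macros above assume an OpBuilder named `b` and a Location named `l` in
// scope. As an illustrative sketch (the variable names here are made up):
//   Value nxPos = ADDI(pos, C_IDX(1));
// expands to
//   Value nxPos = arith::AddIOp::create(b, l, pos,
//                                       constantIndex(b, l, 1)).getResult();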
template <bool hasPosBuffer>
class SparseLevel : public SparseTensorLevel {
  // The buffer array holds the position buffer and the coordinate buffer when
  // the level needs a position array, and just the coordinate buffer
  // otherwise.
  using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
                                     std::array<Value, 1>>;

public:
  SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
              BufferT buffers)
      : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}

  ValueRange getLvlBuffers() const override { return buffers; }

  Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                  Value iv) const override {
    SmallVector<Value> memCrd(batchPrefix);
    memCrd.push_back(iv);
    return genIndexLoad(b, l, getCrdBuf(), memCrd);
  }

protected:
  // Only available when the level has a position buffer.
  template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
  Value getPosBuf() const {
    return buffers[0];
  }

  Value getCrdBuf() const {
    if constexpr (hasPosBuffer)
      return buffers[1];
    else
      return buffers[0];
  }

  const BufferT buffers;
};
class DenseLevel : public SparseTensorLevel {
public:
  DenseLevel(unsigned tid, Level lvl, Value lvlSize)
      : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize) {}

  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
    llvm_unreachable("locate random-accessible level instead");
  }

  ValueRange getLvlBuffers() const override { return {}; }

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
    assert(!inPadZone && "Not implemented");
    Value p = parentPos.front();
    Value posLo = MULI(p, lvlSize);
    return {posLo, lvlSize};
  }
};
class BatchLevel : public SparseTensorLevel {
public:
  BatchLevel(unsigned tid, Level lvl, Value lvlSize)
      : SparseTensorLevel(tid, lvl, LevelFormat::Batch, lvlSize) {}

  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
    llvm_unreachable("locate random-accessible level instead");
  }

  ValueRange getLvlBuffers() const override { return {}; }

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(!inPadZone && "Not implemented");
    assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
    // No need to linearize the position for non-annotated tensors.
    return {C_IDX(0), lvlSize};
  }
};
class CompressedLevel : public SparseLevel<true> {
public:
  CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                  Value posBuffer, Value crdBuffer)
      : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(parentPos.size() == 1 &&
           "compressed level must be the first non-unique level.");

    auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
      Value p = parentPos.front();
      SmallVector<Value> memCrd(batchPrefix);
      memCrd.push_back(p);
      Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
      memCrd.back() = ADDI(p, C_IDX(1));
      Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
      return {pLo, pHi};
    };

    if (inPadZone == nullptr)
      return loadRange();

    SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
    scf::IfOp posRangeIf = scf::IfOp::create(b, l, types, inPadZone, true);

    // True branch: return a "fake" empty range [0, 0) when the parent
    // iterator is in the pad zone.
    b.setInsertionPointToStart(posRangeIf.thenBlock());
    SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
    scf::YieldOp::create(b, l, emptyRange);

    // False branch: return the actual range loaded from the position buffer.
    b.setInsertionPointToStart(posRangeIf.elseBlock());
    auto [pLo, pHi] = loadRange();
    SmallVector<Value, 2> loadedRange{pLo, pHi};
    scf::YieldOp::create(b, l, loadedRange);

    b.setInsertionPointAfter(posRangeIf);
    ValueRange posRange = posRangeIf.getResults();
    return {posRange.front(), posRange.back()};
  }
};
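// A rough sketch of the IR generated by peekRangeAt() above for the padded
// case (types and exact operand order elided for readability):
//   %lo, %hi = scf.if %inPadZone -> (index, index) {
//     scf.yield %c0, %c0          // "fake" empty range [0, 0)
//   } else {
//     %pLo = <load pos[p]>
//     %pHi = <load pos[p + 1]>
//     scf.yield %pLo, %pHi        // actual position range
//   }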
class LooseCompressedLevel : public SparseLevel<true> {
public:
  LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                       Value posBuffer, Value crdBuffer)
      : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(parentPos.size() == 1 &&
           "loose-compressed level must be the first non-unique level.");
    assert(!inPadZone && "Not implemented");
    // A loose-compressed level stores a pair of positions per parent
    // position: pos[2 * p] and pos[2 * p + 1].
    SmallVector<Value> memCrd(batchPrefix);
    Value p = parentPos.front();
    p = MULI(p, C_IDX(2));
    memCrd.push_back(p);
    Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
    memCrd.back() = ADDI(p, C_IDX(1));
    Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
    return {pLo, pHi};
  }
};
class SingletonLevel : public SparseLevel<false> {
public:
  SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                 Value crdBuffer)
      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(parentPos.size() == 1 || parentPos.size() == 2);
    assert(!inPadZone && "Not implemented");
    Value p = parentPos.front();
    Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;

    if (segHi == nullptr)
      return {p, ADDI(p, C_IDX(1))};
    // Use segHi as the upper bound of the range.
    return {p, segHi};
  }

  ValuePair collapseRangeBetween(OpBuilder &b, Location l,
                                 ValueRange batchPrefix,
                                 std::pair<Value, Value> parentRange)
      const override {
    // A singleton level inherits its parent's range unchanged.
    return parentRange;
  }
};
class NOutOfMLevel : public SparseLevel<false> {
public:
  NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
               Value crdBuffer)
      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(parentPos.size() == 1 && isUnique() &&
           "n:m level can not be non-unique.");
    assert(!inPadZone && "Not implemented");
    // Each n:m block has exactly n specified elements.
    auto n = getN(lt);
    Value posLo = MULI(parentPos.front(), C_IDX(n));
    return {posLo, ADDI(posLo, C_IDX(n))};
  }
};
// Generates code that yields `elseRet` when the iterator is out of bound, and
// otherwise delegates to `builder` on the dereferenced coordinate.
static scf::ValueVector
genWhenInBound(OpBuilder &b, Location l, SparseIterator &it, ValueRange elseRet,
               llvm::function_ref<scf::ValueVector(OpBuilder &, Location, Value)>
                   builder) {
  TypeRange ifRetTypes = elseRet.getTypes();
  auto ifOp = scf::IfOp::create(b, l, ifRetTypes, it.genNotEnd(b, l), true);

  b.setInsertionPointToStart(ifOp.thenBlock());
  Value crd = it.deref(b, l);
  scf::ValueVector ret = builder(b, l, crd);
  YIELD(ret);

  b.setInsertionPointToStart(ifOp.elseBlock());
  YIELD(elseRet);

  b.setInsertionPointAfter(ifOp);
  return ifOp.getResults();
}
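// A typical use of genWhenInBound() (a sketch mirroring
// FilterIterator::genNotEndImpl further below): compute a predicate only when
// the iterator is still in bound, and otherwise yield `elseRet`:
//   scf::ValueVector r = genWhenInBound(
//       b, l, it, /*elseRet=*/C_FALSE,
//       [&](OpBuilder &b, Location l, Value crd) -> scf::ValueVector {
//         return {CMPI(ult, crd, size)};
//       });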
// A simple wrapper that encodes the information of a concrete sparse tensor
// level.
class ConcreteIterator : public SparseIterator {
protected:
  ConcreteIterator(const SparseTensorLevel &stl, IterKind kind,
                   unsigned cursorValCnt)
      : SparseIterator(kind, stl.tid, stl.lvl, cursorValCnt, cursorValsStorage),
        stl(stl), cursorValsStorage(cursorValCnt, nullptr) {
    assert(getCursor().size() == cursorValCnt);
  };

public:
  // For LLVM-style RTTI.
  static bool classof(const SparseIterator *from) {
    return from->kind == IterKind::kTrivial;
  }

  bool isBatchIterator() const override {
    return stl.getLT().isa<LevelFormat::Batch>();
  }
  bool randomAccessible() const override {
    return stl.getLT().hasDenseSemantic();
  };
  bool iteratableByFor() const override { return kind != IterKind::kDedup; };
  Value upperBound(OpBuilder &b, Location l) const override {
    return stl.getSize();
  };

protected:
  const SparseTensorLevel &stl;
  // Owns the cursor storage; all wrappers built on top of a concrete iterator
  // share the same storage so that the iterator values stay synchronized.
  SmallVector<Value> cursorValsStorage;
};
class TrivialIterator : public ConcreteIterator {
public:
  TrivialIterator(const SparseTensorLevel &stl)
      : ConcreteIterator(stl, IterKind::kTrivial, /*cursorValCnt=*/1) {}

  TrivialIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
                  Value posLo, Value posHi)
      : ConcreteIterator(stl, IterKind::kTrivial, /*cursorValCnt=*/1),
        posLo(posLo), posHi(posHi) {
    seek(posLo);
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("trivial<") + stl.toString() + ">";
  }

  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    return {b.getIndexType()};
  }

  SmallVector<Value> serialize() const override {
    SmallVector<Value> ret;
    ret.push_back(getItPos());
    if (randomAccessible()) {
      // The upper bound is implicit for random-access iterators, but posLo is
      // still needed to delinearize the position.
      ret.push_back(posLo);
    } else {
      ret.push_back(posHi);
    }
    return ret;
  };

  void deserialize(ValueRange vs) override {
    assert(vs.size() == 2);
    seek(vs.front());
    if (randomAccessible())
      posLo = vs.back();
    else
      posHi = vs.back();
  };

  void genInitImpl(OpBuilder &b, Location l,
                   const SparseIterator *parent) override;

  ValuePair genForCond(OpBuilder &b, Location l) override {
    if (randomAccessible())
      return {deref(b, l), upperBound(b, l)};
    return std::make_pair(getItPos(), posHi);
  }

  Value genNotEndImpl(OpBuilder &b, Location l) override {
    // We used the first level bound as the bound for the collapsed set of
    // levels.
    return CMPI(ult, getItPos(), posHi);
  }

  Value derefImpl(OpBuilder &b, Location l) override {
    if (randomAccessible()) {
      updateCrd(SUBI(getItPos(), posLo));
    } else {
      updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getItPos()));
    }
    return getCrd();
  };

  ValueRange forwardImpl(OpBuilder &b, Location l) override {
    seek(ADDI(getItPos(), C_IDX(1)));
    return getCursor();
  }

  ValueRange forwardIf(OpBuilder &b, Location l, Value cond) override {
    Value curPos = getCursor().front();
    Value nxPos = forward(b, l).front();
    seek(SELECT(cond, nxPos, curPos));
    return getCursor();
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    assert(randomAccessible());
    // Seek to the linearized position.
    seek(ADDI(crd, posLo));
    updateCrd(crd);
    if (isBatchIterator()) {
      // If this is a batch iterator, also update the batch coordinate.
      assert(batchCrds.size() > lvl);
      batchCrds[lvl] = crd;
    }
  }

  Value getItPos() const { return getCursor().front(); }
  Value posLo, posHi;
};
class DedupIterator : public ConcreteIterator {
private:
  Value genSegmentHigh(OpBuilder &b, Location l, Value pos);

public:
  DedupIterator(const SparseTensorLevel &stl)
      : ConcreteIterator(stl, IterKind::kDedup, /*cursorValCnt=*/2) {
    assert(!stl.isUnique());
  }

  DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
                Value posLo, Value posHi)
      : ConcreteIterator(stl, IterKind::kDedup, /*cursorValCnt=*/2),
        posHi(posHi) {
    assert(!stl.isUnique());
    seek({posLo, genSegmentHigh(b, l, posLo)});
  }

  // For LLVM-style RTTI.
  static bool classof(const SparseIterator *from) {
    return from->kind == IterKind::kDedup;
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("dedup<") + stl.toString() + ">";
  }

  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    // The cursor is the pair (pos, segHi).
    return {b.getIndexType(), b.getIndexType()};
  }

  void genInitImpl(OpBuilder &b, Location l,
                   const SparseIterator *parent) override {
    Value c0 = C_IDX(0);
    ValueRange pPos = c0;
    // If the parent iterator is a batch iterator, we also start from 0 (but
    // on a different batch).
    if (parent && !parent->isBatchIterator())
      pPos = parent->getCurPosition();

    Value posLo;
    ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);

    seek({posLo, genSegmentHigh(b, l, posLo)});
  }

  SmallVector<Value> serialize() const override {
    SmallVector<Value> ret;
    ret.append(getCursor().begin(), getCursor().end());
    ret.push_back(posHi);
    return ret;
  };

  void deserialize(ValueRange vs) override {
    assert(vs.size() == 3);
    seek(vs.take_front(getCursor().size()));
    posHi = vs.back();
  };

  Value genNotEndImpl(OpBuilder &b, Location l) override {
    return CMPI(ult, getPos(), posHi);
  }

  Value derefImpl(OpBuilder &b, Location l) override {
    updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getPos()));
    return getCrd();
  };

  ValueRange forwardImpl(OpBuilder &b, Location l) override {
    // Skip the entire duplicate segment.
    Value nxPos = getSegHi();
    seek({nxPos, genSegmentHigh(b, l, nxPos)});
    return getCursor();
  }

  Value getPos() const { return getCursor()[0]; }
  Value getSegHi() const { return getCursor()[1]; }

  Value posHi;
};
// A simple wrapper for building iterators on top of another iterator; by
// default it delegates all methods to the wrapped iterator.
class SimpleWrapIterator : public SparseIterator {
public:
  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
                     unsigned extraCursorVal = 0)
      : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}

  void setSparseEmitStrategy(SparseEmitStrategy strategy) override {
    wrap->setSparseEmitStrategy(strategy);
  }

  SparseEmitStrategy getSparseEmitStrategy() const override {
    return wrap->getSparseEmitStrategy();
  }

  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    return wrap->getCursorValTypes(b);
  }
  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
  bool randomAccessible() const override { return wrap->randomAccessible(); };
  bool iteratableByFor() const override { return wrap->iteratableByFor(); };

  SmallVector<Value> serialize() const override { return wrap->serialize(); };
  void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
  ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
  void genInitImpl(OpBuilder &b, Location l,
                   const SparseIterator *parent) override {
    wrap->genInit(b, l, parent);
  }
  Value genNotEndImpl(OpBuilder &b, Location l) override {
    return wrap->genNotEndImpl(b, l);
  }
  ValueRange forwardImpl(OpBuilder &b, Location l) override {
    return wrap->forward(b, l);
  };
  Value upperBound(OpBuilder &b, Location l) const override {
    return wrap->upperBound(b, l);
  };

  Value derefImpl(OpBuilder &b, Location l) override {
    return wrap->derefImpl(b, l);
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    return wrap->locate(b, l, crd);
  }

  SparseIterator &getWrappedIterator() const { return *wrap; }

protected:
  std::unique_ptr<SparseIterator> wrap;
};
// A filter iterator wrapped from another iterator. The filter iterator
// updates the wrapped iterator *in-place*.
class FilterIterator : public SimpleWrapIterator {
  // Coordinate translation between the filtered space and the wrapped space:
  // crd = (wrapCrd - offset) / stride.
  Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) const {
    return DIVUI(SUBI(wrapCrd, offset), stride);
  }
  // wrapCrd = crd * stride + offset.
  Value toWrapCrd(OpBuilder &b, Location l, Value crd) const {
    return ADDI(MULI(crd, stride), offset);
  }
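  // Worked example: with offset = 2 and stride = 3, wrapped coordinates
  // 2, 5, 8, ... map to filter coordinates 0, 1, 2, .... A wrapped coordinate
  // such as 6 is not legit: fromWrapCrd(6) == (6 - 2) / 3 == 1, but
  // toWrapCrd(1) == 5 != 6, which is exactly the on-stride test used in
  // genCrdNotLegitPredicate() below.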
  Value genCrdNotLegitPredicate(OpBuilder &b, Location l, Value wrapCrd);

  Value genShouldFilter(OpBuilder &b, Location l);

public:
  FilterIterator(std::unique_ptr<SparseIterator> &&wrap, Value offset,
                 Value stride, Value size)
      : SimpleWrapIterator(std::move(wrap), IterKind::kFilter), offset(offset),
        stride(stride), size(size) {}

  // For LLVM-style RTTI.
  static bool classof(const SparseIterator *from) {
    return from->kind == IterKind::kFilter;
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("filter<") + wrap->getDebugInterfacePrefix() + ">";
  }

  bool iteratableByFor() const override { return randomAccessible(); };
  Value upperBound(OpBuilder &b, Location l) const override { return size; };

  void genInitImpl(OpBuilder &b, Location l,
                   const SparseIterator *parent) override {
    wrap->genInit(b, l, parent);
    if (!randomAccessible()) {
      // Forward to the first legit coordinate.
      forwardIf(b, l, genShouldFilter(b, l));
    } else {
      // Else, locate to the offset, which is the first coordinate included by
      // the slice.
      wrap->locate(b, l, offset);
    }
  }

  Value genNotEndImpl(OpBuilder &b, Location l) override;

  Value derefImpl(OpBuilder &b, Location l) override {
    updateCrd(fromWrapCrd(b, l, wrap->deref(b, l)));
    return getCrd();
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    assert(randomAccessible());
    wrap->locate(b, l, toWrapCrd(b, l, crd));
    updateCrd(crd);
  }

  ValueRange forwardImpl(OpBuilder &b, Location l) override;

  Value offset, stride, size;
};
// A pad iterator wrapped from another iterator. The pad iterator extends the
// space of the wrapped iterator with padLow leading and padHigh trailing
// padded coordinates.
class PadIterator : public SimpleWrapIterator {
public:
  PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
              Value padHigh)
      : SimpleWrapIterator(std::move(wrap), IterKind::kPad,
                           wrap->randomAccessible() ? 1 : 0),
        padLow(padLow), padHigh(padHigh) {}

  // For LLVM-style RTTI.
  static bool classof(const SparseIterator *from) {
    return from->kind == IterKind::kPad;
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
  }

  // Returns a pair of values for the lower and upper loop bounds.
  ValuePair genForCond(OpBuilder &b, Location l) override {
    if (randomAccessible())
      return {getCrd(), upperBound(b, l)};
    return wrap->genForCond(b, l);
  }

  // For a padded dense iterator, the cursor values are used directly as the
  // current position.
  ValueRange getCurPosition() const override { return getCursor(); }

  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    SmallVector<Type> ret = wrap->getCursorValTypes(b);
    // An extra boolean flag that indicates whether the current coordinate
    // lies in a pad zone.
    if (randomAccessible())
      ret.push_back(b.getI1Type());
    return ret;
  }

  // The upper bound after padding becomes `upperBound + padLow + padHigh`.
  Value upperBound(OpBuilder &b, Location l) const override {
    return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
  };

  // The padded coordinate is the wrapped coordinate shifted by padLow.
  Value derefImpl(OpBuilder &b, Location l) override {
    updateCrd(ADDI(wrap->deref(b, l), padLow));
    return getCrd();
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    assert(randomAccessible());
    wrap->locate(b, l, SUBI(crd, padLow));

    // inPadZone = crd < padLow || crd >= size + padLow.
    Value inPadLow = CMPI(ult, crd, padLow);
    Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow));
    getMutCursorVals().back() = ORI(inPadLow, inPadHigh);

    updateCrd(crd);
  }

  Value padLow, padHigh;
};
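// Worked example for PadIterator: with padLow = 2 and a wrapped space of
// size 4, the padded space has size 2 + 4 + padHigh. Locating crd = 1 sets
// the inPadZone flag (1 < padLow), while crd = 3 locates the wrapped
// iterator at coordinate 3 - 2 == 1 with the flag cleared.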
class NonEmptySubSectIterator : public SparseIterator {
public:
  using TraverseBuilder = llvm::function_ref<scf::ValueVector(
      OpBuilder &, Location, const SparseIterator *, ValueRange)>;

  NonEmptySubSectIterator(OpBuilder &b, Location l,
                          const SparseIterator *parent,
                          std::unique_ptr<SparseIterator> &&delegate,
                          Value subSectSz)
      : SparseIterator(IterKind::kNonEmptySubSect, delegate->tid,
                       delegate->lvl, /*cursorValCnt=*/3, subSectMeta),
        parent(parent), delegate(std::move(delegate)),
        tupleSz(this->delegate->serialize().size()), subSectSz(subSectSz) {
    auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
    if (p == nullptr) {
      // Extract subsections along the root level.
      maxTupleCnt = C_IDX(1);
    } else if (p->lvl == lvl) {
      // Extract subsections along the same level.
      maxTupleCnt = p->maxTupleCnt;
      assert(false && "Not implemented.");
    } else {
      // Extract subsections along the previous level.
      assert(p->lvl + 1 == lvl);
      maxTupleCnt = MULI(p->maxTupleCnt, p->subSectSz);
    }
    // No extra buffer is needed to find subsections on random-accessible
    // levels.
    if (randomAccessible())
      return;
    subSectPosBuf = allocSubSectPosBuf(b, l);
  }
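  // Illustrative sizing example: if the parent iterator caches up to N tuples
  // and extracts subsections of size S on the previous level, this level may
  // need up to N * S tuples, which is why maxTupleCnt is multiplied along the
  // parent chain above.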
  // For LLVM-style RTTI.
  static bool classof(const SparseIterator *from) {
    return from->kind == IterKind::kNonEmptySubSect;
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("ne_sub<") + delegate->getDebugInterfacePrefix() + ">";
  }

  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    // The cursor values are (minCrd, absolute offset, notEnd).
    return {b.getIndexType(), b.getIndexType(), b.getI1Type()};
  }

  // The subsection position buffer is a 2-D buffer of
  // [maxTupleCnt x (tupleSz + 1)] index values; each row holds the serialized
  // iterator values followed by the next-level start.
  Value allocSubSectPosBuf(OpBuilder &b, Location l) {
    return memref::AllocaOp::create(
        b, l,
        MemRefType::get({ShapedType::kDynamic, tupleSz + 1}, b.getIndexType()),
        maxTupleCnt);
  }

  void storeNxLvlStart(OpBuilder &b, Location l, Value tupleId,
                       Value start) const {
    memref::StoreOp::create(b, l, start, subSectPosBuf,
                            ValueRange{tupleId, C_IDX(tupleSz)});
  }

  Value loadNxLvlStart(OpBuilder &b, Location l, Value tupleId) const {
    return memref::LoadOp::create(b, l, subSectPosBuf,
                                  ValueRange{tupleId, C_IDX(tupleSz)});
  }

  void storeCursorVals(OpBuilder &b, Location l, Value tupleId,
                       ValueRange itVals) const {
    assert(itVals.size() == tupleSz);
    for (unsigned i = 0; i < tupleSz; i++) {
      memref::StoreOp::create(b, l, itVals[i], subSectPosBuf,
                              ValueRange{tupleId, C_IDX(i)});
    }
  }

  SmallVector<Value> loadCursorVals(OpBuilder &b, Location l,
                                    Value tupleId) const {
    SmallVector<Value> ret;
    for (unsigned i = 0; i < tupleSz; i++) {
      Value v = memref::LoadOp::create(b, l, subSectPosBuf,
                                       ValueRange{tupleId, C_IDX(i)});
      ret.push_back(v);
    }
    return ret;
  }

  bool isSubSectRoot() const {
    return !parent || !llvm::isa<NonEmptySubSectIterator>(parent);
  }

  // Generates code that inflates the subsection tree to the current level
  // such that every leaf node is visited.
  ValueRange inflateSubSectTree(OpBuilder &b, Location l, ValueRange reduc,
                                TraverseBuilder builder) const;

  bool isBatchIterator() const override { return delegate->isBatchIterator(); }
  bool randomAccessible() const override {
    return delegate->randomAccessible();
  };
  bool iteratableByFor() const override { return randomAccessible(); };
  Value upperBound(OpBuilder &b, Location l) const override {
    auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
    Value parentUB =
        p && p->lvl == lvl ? p->upperBound(b, l) : delegate->upperBound(b, l);
    return ADDI(SUBI(parentUB, subSectSz), C_IDX(1));
  };

  void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override;

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    Value absOff = crd;

    if (isSubSectRoot())
      delegate->locate(b, l, absOff);
    else
      assert(parent->lvl + 1 == lvl);

    seek(ValueRange{absOff, absOff, C_TRUE});
    updateCrd(crd);
  }

  Value toSubSectCrd(OpBuilder &b, Location l, Value wrapCrd) const {
    return SUBI(wrapCrd, getAbsOff());
  }

  Value genNotEndImpl(OpBuilder &b, Location l) override { return getNotEnd(); }

  Value derefImpl(OpBuilder &b, Location l) override {
    // Use the relative offset to co-iterate.
    Value crd;
    auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
    if (p && p->lvl == lvl)
      crd = SUBI(getAbsOff(), p->getAbsOff());
    else
      crd = getAbsOff();

    updateCrd(crd);
    return crd;
  };

  ValueRange forwardImpl(OpBuilder &b, Location l) override;

  Value getMinCrd() const { return subSectMeta[0]; }
  Value getAbsOff() const { return subSectMeta[1]; }
  Value getNotEnd() const { return subSectMeta[2]; }

  const SparseIterator *parent;
  std::unique_ptr<SparseIterator> delegate;

  // Number of values required to serialize the wrapped iterator.
  const unsigned tupleSz;
  // The maximum number of tuples and the actual number of tuples.
  Value maxTupleCnt, tupleCnt;
  // The buffer that caches the tuples serialized from the wrapped iterator.
  Value subSectPosBuf;

  const Value subSectSz;

  // The cursor values: (minCrd, absolute offset, notEnd).
  SmallVector<Value, 3> subSectMeta{nullptr, nullptr, nullptr};
};
class SubSectIterator;

// A wrapper that helps generating code to traverse a subsection, used by both
// `NonEmptySubSectIterator` and `SubSectIterator`.
struct SubSectIterHelper {
  explicit SubSectIterHelper(const SubSectIterator &iter);
  explicit SubSectIterHelper(const NonEmptySubSectIterator &subSect);

  // Delegate methods.
  void deserializeFromTupleId(OpBuilder &b, Location l, Value tupleId);
  void locate(OpBuilder &b, Location l, Value crd);
  Value genNotEnd(OpBuilder &b, Location l);
  Value deref(OpBuilder &b, Location l);
  ValueRange forward(OpBuilder &b, Location l);

  const NonEmptySubSectIterator &subSect;
  SparseIterator &wrap;
};
class SubSectIterator : public SparseIterator {
public:
  SubSectIterator(const NonEmptySubSectIterator &subSect,
                  const SparseIterator &parent,
                  std::unique_ptr<SparseIterator> &&wrap)
      : SparseIterator(IterKind::kSubSect, *wrap,
                       /*extraCursorCnt=*/wrap->randomAccessible() ? 0 : 1),
        subSect(subSect), wrap(std::move(wrap)), parent(parent), helper(*this) {
    assert(subSect.tid == tid && subSect.lvl == lvl);
    assert(parent.kind != IterKind::kSubSect || parent.lvl + 1 == lvl);
  };

  // For LLVM-style RTTI.
  static bool classof(const SparseIterator *from) {
    return from->kind == IterKind::kSubSect;
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("subsect<") + wrap->getDebugInterfacePrefix() + ">";
  }

  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    SmallVector<Type> ret = wrap->getCursorValTypes(b);
    if (!randomAccessible())
      ret.push_back(b.getIndexType()); // The extra counter.
    return ret;
  }

  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
  bool randomAccessible() const override { return wrap->randomAccessible(); };
  bool iteratableByFor() const override { return randomAccessible(); };
  Value upperBound(OpBuilder &b, Location l) const override {
    return subSect.subSectSz;
  }

  ValueRange getCurPosition() const override { return wrap->getCurPosition(); };

  Value getNxLvlTupleId(OpBuilder &b, Location l) const {
    if (randomAccessible()) {
      return ADDI(getCrd(), nxLvlTupleStart);
    };
    return ADDI(getCursor().back(), nxLvlTupleStart);
  }

  void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override {
    if (randomAccessible()) {
      if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
        assert(p->lvl + 1 == lvl);
        wrap->genInit(b, l, p);
        // Linearize the dense subsection index.
        nxLvlTupleStart = MULI(subSect.subSectSz, p->getNxLvlTupleId(b, l));
      } else {
        assert(subSect.lvl == lvl && subSect.isSubSectRoot());
        wrap->deserialize(subSect.delegate->serialize());
        nxLvlTupleStart = C_IDX(0);
      }
      return;
    }
    assert(!randomAccessible());
    assert(getCursor().size() == wrap->getCursor().size() + 1);
    // Extra counter that counts the number of actually visited coordinates in
    // the sparse subsection.
    getMutCursorVals().back() = C_IDX(0);
    Value tupleId;
    if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
      assert(p->lvl + 1 == lvl);
      tupleId = p->getNxLvlTupleId(b, l);
    } else {
      assert(subSect.lvl == lvl && subSect.isSubSectRoot());
      tupleId = C_IDX(0);
    }
    nxLvlTupleStart = subSect.loadNxLvlStart(b, l, tupleId);
    helper.deserializeFromTupleId(b, l, tupleId);
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    helper.locate(b, l, crd);
    updateCrd(crd);
  }

  Value genNotEndImpl(OpBuilder &b, Location l) override {
    return helper.genNotEnd(b, l);
  }

  Value derefImpl(OpBuilder &b, Location l) override {
    Value crd = helper.deref(b, l);
    updateCrd(crd);
    return crd;
  };

  ValueRange forwardImpl(OpBuilder &b, Location l) override {
    helper.forward(b, l);
    assert(!randomAccessible());
    assert(getCursor().size() == wrap->getCursor().size() + 1);
    getMutCursorVals().back() = ADDI(getCursor().back(), C_IDX(1));
    return getCursor();
  };

  Value nxLvlTupleStart;

  const NonEmptySubSectIterator &subSect;
  std::unique_ptr<SparseIterator> wrap;
  const SparseIterator &parent;

  SubSectIterHelper helper;
};
void SparseIterator::genInit(OpBuilder &b, Location l,
                             const SparseIterator *p) {
  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
    std::string prefix = getDebugInterfacePrefix();
    Operation *begin = b.create(l, b.getStringAttr(prefix + ".begin"), {},
                                getCursorValTypes(b));
    seek(begin->getResults());
    return;
  }
  genInitImpl(b, l, p);
}

Value SparseIterator::genNotEnd(OpBuilder &b, Location l) {
  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
    std::string prefix = getDebugInterfacePrefix();
    Operation *notEnd = b.create(l, b.getStringAttr(prefix + ".not_end"),
                                 getCursor(), b.getI1Type());
    return notEnd->getResult(0);
  }
  return genNotEndImpl(b, l);
}

void SparseIterator::locate(OpBuilder &b, Location l, Value crd) {
  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
    std::string prefix = getDebugInterfacePrefix();
    SmallVector<Value> args(getCursor().begin(), getCursor().end());
    args.push_back(crd);
    Operation *locate = b.create(l, b.getStringAttr(prefix + ".locate"), args,
                                 getCursorValTypes(b));
    seek(locate->getResults());
    updateCrd(crd);
    return;
  }
  locateImpl(b, l, crd);
}

ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
    std::string prefix = getDebugInterfacePrefix();
    Operation *next = b.create(l, b.getStringAttr(prefix + ".next"),
                               getCursor(), getCursorValTypes(b));
    seek(next->getResults());
    return getCursor();
  }
  return forwardImpl(b, l);
}
ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) {
  auto ifOp = scf::IfOp::create(b, l, getCursor().getTypes(), cond, true);
  // Generate the else branch first; otherwise the iterator values will
  // already have been updated by `forward()`.
  b.setInsertionPointToStart(ifOp.elseBlock());
  YIELD(getCursor());

  b.setInsertionPointToStart(ifOp.thenBlock());
  YIELD(forward(b, l));

  b.setInsertionPointAfter(ifOp);
  seek(ifOp.getResults());
  return getCursor();
}
Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
  auto whileOp = scf::WhileOp::create(
      b, l, pos.getType(), pos,
      /*beforeBuilder=*/
      [this, pos](OpBuilder &b, Location l, ValueRange ivs) {
        Value inBound = CMPI(ult, ivs.front(), posHi);
        auto ifInBound = scf::IfOp::create(b, l, b.getI1Type(), inBound, true);
        {
          OpBuilder::InsertionGuard guard(b);
          // If in bound, load the next coordinate and check for duplication.
          b.setInsertionPointToStart(ifInBound.thenBlock());
          Value headCrd = stl.peekCrdAt(b, l, getBatchCrds(), pos);
          Value tailCrd = stl.peekCrdAt(b, l, getBatchCrds(), ivs.front());
          Value isDup = CMPI(eq, headCrd, tailCrd);
          YIELD(isDup);
          // Else, the position is out of bound; yield false.
          b.setInsertionPointToStart(ifInBound.elseBlock());
          YIELD(constantI1(b, l, false));
        }
        scf::ConditionOp::create(b, l, ifInBound.getResults()[0], ivs);
      },
      /*afterBuilder=*/
      [](OpBuilder &b, Location l, ValueRange ivs) {
        YIELD(ADDI(ivs[0], C_IDX(1)));
      });
  // The first position whose coordinate differs from the segment head.
  return whileOp.getResult(0);
}
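// Example: for a non-unique level with coordinates [3, 3, 3, 5, ...],
// genSegmentHigh(b, l, /*pos=*/0) emits a while loop that scans forward while
// the coordinate equals the segment head and returns 3, the first position
// whose coordinate differs (or posHi if the segment runs to the end).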
Value FilterIterator::genCrdNotLegitPredicate(OpBuilder &b, Location l,
                                              Value wrapCrd) {
  Value crd = fromWrapCrd(b, l, wrapCrd);
  // Test whether the coordinate is on stride.
  Value notlegit = CMPI(ne, toWrapCrd(b, l, crd), wrapCrd);
  // Test wrapCrd < offset.
  notlegit = ORI(CMPI(ult, wrapCrd, offset), notlegit);
  // Test crd >= size.
  notlegit = ORI(CMPI(uge, crd, size), notlegit);
  return notlegit;
}

Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
  auto r = genWhenInBound(
      b, l, *wrap, C_FALSE,
      [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
        Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
        return {notLegit};
      });

  return llvm::getSingleElement(r);
}

Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) {
  assert(!wrap->randomAccessible());
  auto r = genWhenInBound(
      b, l, *wrap, C_FALSE,
      [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
        Value crd = fromWrapCrd(b, l, wrapCrd);
        // crd < size
        return {CMPI(ult, crd, size)};
      });
  return llvm::getSingleElement(r);
}
ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
  assert(!randomAccessible());
  // Generates the following loop to forward the wrapped iterator:
  //
  //   bool isFirst = true;
  //   while (!it.end() && (!legit(*it) || isFirst)) {
  //     it++;
  //     isFirst = false;
  //   }
  //
  // The first `it++` is not hoisted out of the loop because forwarding the
  // wrapped iterator might be expensive (e.g., when it forwards a
  // subsection).
  Value isFirst = constantI1(b, l, true);

  SmallVector<Value> whileArgs(getCursor().begin(), getCursor().end());
  whileArgs.push_back(isFirst);
  auto whileOp = scf::WhileOp::create(
      b, l, ValueRange(whileArgs).getTypes(), whileArgs,
      /*beforeBuilder=*/
      [this](OpBuilder &b, Location l, ValueRange ivs) {
        ValueRange isFirst = linkNewScope(ivs);
        assert(isFirst.size() == 1);
        scf::ValueVector cont =
            genWhenInBound(b, l, *wrap, C_FALSE,
                           [this, isFirst](OpBuilder &b, Location l,
                                           Value wrapCrd) -> scf::ValueVector {
                             // crd < size && !legit(*it)
                             Value notLegit =
                                 genCrdNotLegitPredicate(b, l, wrapCrd);
                             Value crd = fromWrapCrd(b, l, wrapCrd);
                             Value ret = ANDI(CMPI(ult, crd, size), notLegit);
                             ret = ORI(ret, llvm::getSingleElement(isFirst));
                             return {ret};
                           });
        scf::ConditionOp::create(b, l, cont.front(), ivs);
      },
      /*afterBuilder=*/
      [this](OpBuilder &b, Location l, ValueRange ivs) {
        linkNewScope(ivs);
        wrap->forward(b, l);
        SmallVector<Value> yieldVals(getCursor().begin(), getCursor().end());
        yieldVals.push_back(constantI1(b, l, false));
        YIELD(yieldVals);
      });

  b.setInsertionPointAfter(whileOp);
  linkNewScope(whileOp.getResults());
  return getCursor();
}
SubSectIterHelper::SubSectIterHelper(const NonEmptySubSectIterator &subSect)
    : subSect(subSect), wrap(*subSect.delegate) {}

SubSectIterHelper::SubSectIterHelper(const SubSectIterator &iter)
    : subSect(iter.subSect), wrap(*iter.wrap) {}

void SubSectIterHelper::deserializeFromTupleId(OpBuilder &b, Location l,
                                               Value tupleId) {
  assert(!subSect.randomAccessible());
  wrap.deserialize(subSect.loadCursorVals(b, l, tupleId));
}

void SubSectIterHelper::locate(OpBuilder &b, Location l, Value crd) {
  Value absCrd = ADDI(crd, subSect.getAbsOff());
  wrap.locate(b, l, absCrd);
}

Value SubSectIterHelper::genNotEnd(OpBuilder &b, Location l) {
  assert(!wrap.randomAccessible());
  auto r = genWhenInBound(
      b, l, wrap, C_FALSE,
      [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
        Value crd = SUBI(wrapCrd, subSect.getAbsOff());
        // crd < size
        return {CMPI(ult, crd, subSect.subSectSz)};
      });
  return llvm::getSingleElement(r);
}

Value SubSectIterHelper::deref(OpBuilder &b, Location l) {
  Value wrapCrd = wrap.deref(b, l);
  Value crd = subSect.toSubSectCrd(b, l, wrapCrd);
  return crd;
}

ValueRange SubSectIterHelper::forward(OpBuilder &b, Location l) {
  return wrap.forward(b, l);
}
ValueRange NonEmptySubSectIterator::inflateSubSectTree(
    OpBuilder &b, Location l, ValueRange reduc, TraverseBuilder builder) const {
  // Set up the helper to traverse a sparse subsection.
  SubSectIterHelper helper(*this);
  if (!randomAccessible()) {
    // The subsection tree has been expanded to the current level and cached;
    // traverse all the leaves and expand them to the next level.
    SmallVector<Value> iterArgs;
    iterArgs.push_back(C_IDX(0));
    iterArgs.append(reduc.begin(), reduc.end());
    auto forEachLeaf = scf::ForOp::create(
        b, l, /*lb=*/C_IDX(0), /*ub=*/tupleCnt, /*step=*/C_IDX(1), iterArgs,
        [&helper, &builder](OpBuilder &b, Location l, Value tupleId,
                            ValueRange iterArgs) {
          // Deserialize the iterator at the cached position (tupleId).
          helper.deserializeFromTupleId(b, l, tupleId);

          Value cnt = iterArgs.front();
          // Record the number of leaf nodes included in the subsection; the
          // number indicates the starting tupleId for the next level that
          // corresponds to the current node.
          helper.subSect.storeNxLvlStart(b, l, tupleId, cnt);

          SmallVector<Value> whileArgs(helper.wrap.getCursor());
          whileArgs.append(iterArgs.begin(), iterArgs.end());

          auto whileOp = scf::WhileOp::create(
              b, l, ValueRange(whileArgs).getTypes(), whileArgs,
              /*beforeBuilder=*/
              [&helper](OpBuilder &b, Location l, ValueRange ivs) {
                helper.wrap.linkNewScope(ivs);
                scf::ConditionOp::create(b, l, helper.genNotEnd(b, l), ivs);
              },
              /*afterBuilder=*/
              [&helper, &builder](OpBuilder &b, Location l, ValueRange ivs) {
                ValueRange remIter = helper.wrap.linkNewScope(ivs);
                Value cnt = remIter.front();
                ValueRange userIter = remIter.drop_front();
                scf::ValueVector userNx = builder(b, l, &helper.wrap, userIter);

                SmallVector<Value> nxIter = helper.forward(b, l);
                nxIter.push_back(ADDI(cnt, C_IDX(1)));
                nxIter.append(userNx.begin(), userNx.end());
                YIELD(nxIter);
              });
          ValueRange res = helper.wrap.linkNewScope(whileOp.getResults());
          YIELD(res);
        });
    return forEachLeaf.getResults().drop_front();
  }

  assert(randomAccessible());
  // Helper lambda that traverses the current dense subsection range.
  auto visitDenseSubSect = [&, this](OpBuilder &b, Location l,
                                     const SparseIterator *parent,
                                     ValueRange reduc) {
    assert(!parent || parent->lvl + 1 == lvl);
    delegate->genInit(b, l, parent);
    auto forOp = scf::ForOp::create(
        b, l, /*lb=*/C_IDX(0), /*ub=*/subSectSz, /*step=*/C_IDX(1), reduc,
        [&](OpBuilder &b, Location l, Value crd, ValueRange iterArgs) {
          helper.locate(b, l, crd);
          scf::ValueVector nx = builder(b, l, &helper.wrap, iterArgs);
          YIELD(nx);
        });
    return forOp.getResults();
  };

  if (isSubSectRoot())
    return visitDenseSubSect(b, l, parent, reduc);

  // Else, this is not the root; recurse until the root is reached.
  auto *p = llvm::cast<NonEmptySubSectIterator>(parent);
  assert(p->lvl + 1 == lvl);
  return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
}
void TrivialIterator::genInitImpl(OpBuilder &b, Location l,
                                  const SparseIterator *parent) {
  if (isBatchIterator() && batchCrds.size() <= stl.lvl)
    batchCrds.resize(stl.lvl + 1, nullptr);

  Value c0 = C_IDX(0);
  ValueRange pPos = c0;
  Value inPadZone = nullptr;
  // If the parent iterator is a batch iterator, we also start from 0 (but on
  // a different batch).
  if (parent && !parent->isBatchIterator()) {
    // The position range is determined by the parent iterator.
    pPos = parent->getCurPosition();
    if (llvm::isa<PadIterator>(parent) && parent->randomAccessible()) {
      // A padded dense iterator introduces a "sparse" pad zone; the last
      // cursor value is the inPadZone flag.
      inPadZone = pPos.back();
      pPos = pPos.drop_back();
    }
  }

  ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
  std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos, inPadZone);
  // Seek to the lowest position.
  seek(posLo);
}
void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
                                          const SparseIterator *) {
  Value c0 = C_IDX(0);
  if (!isSubSectRoot()) {
    assert(parent->lvl + 1 == lvl);
    if (randomAccessible()) {
      // We cannot call wrap->genInit() here to initialize the wrapped
      // iterator, because the parent of the current iterator is still
      // unresolved.
      seek({/*minCrd=*/c0, /*absOff=*/c0, /*notEnd=*/C_TRUE});
      return;
    }

    auto *p = cast<NonEmptySubSectIterator>(parent);
    SmallVector<Value, 3> reduc = {
        C_IDX(-1), // minCrd (initialized to the largest index)
        c0,        // tupleId
    };

    // Expand the subsection tree from the parent level to the current level.
    ValueRange result = p->inflateSubSectTree(
        b, l, reduc,
        [this](OpBuilder &b, Location l, const SparseIterator *parent,
               ValueRange reduc) -> scf::ValueVector {
          assert(parent->lvl + 1 == lvl && reduc.size() == 2);
          Value minCrd = reduc.front();
          Value tupleId = reduc.back();

          // Initialize the subsection range.
          SubSectIterHelper helper(*this);
          helper.wrap.genInit(b, l, parent);

          // Update minCrd.
          minCrd = genWhenInBound(b, l, helper.wrap, minCrd,
                                  [minCrd](OpBuilder &b, Location l,
                                           Value crd) -> scf::ValueVector {
                                    Value min = MINUI(crd, minCrd);
                                    return {min};
                                  })
                       .front();

          // Cache the sparse range.
          storeCursorVals(b, l, tupleId, helper.wrap.serialize());
          tupleId = ADDI(tupleId, C_IDX(1));
          return {minCrd, tupleId};
        });
    assert(result.size() == 2);
    tupleCnt = result.back();

    Value minCrd = result.front();
    Value absOff = offsetFromMinCrd(b, l, minCrd, subSectSz);
    Value notEnd = CMPI(ne, minCrd, C_IDX(-1));
    seek({minCrd, absOff, notEnd});
    return;
  }

  // This is the root level of the subsection, which means that it is resolved
  // to one node.
  assert(isSubSectRoot());

  // Initialize the position; it marks the *lower bound* of the subsection
  // range, while the upper bound is determined by the subsection size.
  delegate->genInit(b, l, parent);
  if (randomAccessible()) {
    seek({/*minCrd=*/c0, /*absOff=*/c0, /*notEnd=*/C_TRUE});
    return;
  }

  // There is only one root node.
  tupleCnt = C_IDX(1);
  // Cache the sparse range.
  storeCursorVals(b, l, c0, delegate->serialize());
  SmallVector<Value> elseRet{c0, c0, /*notEnd=*/C_FALSE};
  auto meta = genWhenInBound(
      b, l, *delegate, elseRet,
      [this](OpBuilder &b, Location l, Value crd) -> scf::ValueVector {
        Value offset = offsetFromMinCrd(b, l, crd, subSectSz);
        return {crd, offset, C_TRUE};
      });

  seek(meta);
}
ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
  assert(!randomAccessible());
  Value c0 = C_IDX(0), c1 = C_IDX(1);
  // Forward to the next non-empty subsection by generating
  //
  //   if (minCrd > offset) {
  //     offset += 1;
  //   } else {
  //     minCrd = nextMinInSubSect();
  //     offset = minCrd - size + 1;
  //   }
  //   if (offset + size > parent.upperBound)
  //     notEnd = false;
  Value fastPathP = CMPI(ugt, getMinCrd(), getAbsOff());
  auto ifOp = scf::IfOp::create(b, l, getCursor().getTypes(), fastPathP, true);
  {
    OpBuilder::InsertionGuard guard(b);
    // Take the fast path: simply forward the offset by one.
    b.setInsertionPointToStart(&ifOp.getThenRegion().front());
    Value nxOffset = ADDI(getAbsOff(), c1);
    YIELD((ValueRange{getMinCrd(), nxOffset, getNotEnd()}));

    // The slow path: recompute the next minimal coordinate over all cached
    // tuples, forwarding the delegate iterators that sit on the current
    // minimum.
    b.setInsertionPointToStart(&ifOp.getElseRegion().front());
    SmallVector<Value, 2> loopArgs{C_IDX(-1), // nextMinCrd
                                   C_FALSE};  // isNotEnd
    auto loopNest = scf::buildLoopNest(
        b, l, c0, tupleCnt, c1, loopArgs,
        [this](OpBuilder &b, Location l, ValueRange ivs,
               ValueRange iterArgs) -> scf::ValueVector {
          Value tupleId = ivs.front();
          SubSectIterHelper helper(*this);
          helper.deserializeFromTupleId(b, l, tupleId);

          return genWhenInBound(
              b, l, *delegate, /*elseRet=*/iterArgs,
              [this, iterArgs, tupleId](OpBuilder &b, Location l,
                                        Value crd) -> scf::ValueVector {
                // If crd == minCrd, forward the delegate iterator and cache
                // the updated values.
                Value isMin = CMPI(eq, crd, getMinCrd());
                delegate->forwardIf(b, l, isMin);
                auto ifIsMin = scf::IfOp::create(b, l, isMin, false);
                b.setInsertionPointToStart(&ifIsMin.getThenRegion().front());
                storeCursorVals(b, l, tupleId, delegate->serialize());
                b.setInsertionPointAfter(ifIsMin);
                // If the delegate is not yet exhausted, update the next
                // minimal coordinate.
                Value nxMin = iterArgs[0];
                return genWhenInBound(b, l, *delegate, /*elseRet=*/iterArgs,
                                      [nxMin](OpBuilder &b, Location l,
                                              Value crd) -> scf::ValueVector {
                                        Value nx = arith::MinUIOp::create(
                                                       b, l, crd, nxMin)
                                                       .getResult();
                                        return {nx, C_TRUE};
                                      });
              });
        });

    scf::ForOp forOp = loopNest.loops.front();
    b.setInsertionPointAfter(forOp);

    Value nxMinCrd = forOp.getResult(0);
    Value nxNotEnd = forOp.getResult(1);
    Value nxAbsOff = offsetFromMinCrd(b, l, nxMinCrd, subSectSz);
    YIELD((ValueRange{nxMinCrd, nxAbsOff, nxNotEnd}));
  }

  Value nxMinCrd = ifOp.getResult(0);
  Value nxAbsOff = ifOp.getResult(1);
  Value nxNotEnd = ifOp.getResult(2);

  // The offset must be forwarded by at least one.
  Value minAbsOff = ADDI(getAbsOff(), c1);
  nxAbsOff = arith::MaxUIOp::create(b, l, minAbsOff, nxAbsOff).getResult();

  seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
  // The coordinate must not exceed the space's upper bound.
  Value crd = deref(b, l);
  nxNotEnd = ANDI(nxNotEnd, CMPI(ult, crd, upperBound(b, l)));

  seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
  return getCursor();
}
SparseIterationSpace::SparseIterationSpace(Location l, OpBuilder &b, Value t,
                                           unsigned tid,
                                           std::pair<Level, Level> lvlRange,
                                           ValueRange parentPos) {
  auto [lvlLo, lvlHi] = lvlRange;

  Value c0 = C_IDX(0);
  if (parentPos.empty())
    parentPos = c0;

  for (Level lvl = lvlLo; lvl < lvlHi; lvl++)
    lvls.emplace_back(makeSparseTensorLevel(b, l, t, tid, lvl));

  bound = lvls.front()->peekRangeAt(b, l, /*batchPrefix=*/{}, parentPos);
  for (auto &lvl : getLvlRef().drop_front())
    bound = lvl->collapseRangeBetween(b, l, /*batchPrefix=*/{}, bound);
}
SparseIterationSpace SparseIterationSpace::fromValues(IterSpaceType dstTp,
                                                      ValueRange values,
                                                      unsigned tid) {
  // Reconstruct every sparse tensor level.
  SparseIterationSpace space;
  for (auto [i, lt] : llvm::enumerate(dstTp.getLvlTypes())) {
    unsigned bufferCnt = 0;
    if (lt.isWithPosLT())
      bufferCnt++;
    if (lt.isWithCrdLT())
      bufferCnt++;
    // Sparse tensor buffers.
    ValueRange buffers = values.take_front(bufferCnt);
    values = values.drop_front(bufferCnt);

    // Level size.
    Value sz = values.front();
    values = values.drop_front();
    space.lvls.push_back(
        makeSparseTensorLevel(lt, sz, buffers, tid, i + dstTp.getLoLvl()));
  }
  // The iteration space bounds.
  space.bound = std::make_pair(values[0], values[1]);
  values = values.drop_front(2);

  // All values must have been consumed.
  assert(values.empty());
  return space;
}
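// Illustrative layout consumed by fromValues(): for a 2-D space with level
// types [compressed, singleton], `values` is expected to be flattened as
//   [posBuf0, crdBuf0, sz0, crdBuf1, sz1, boundLo, boundHi]
// i.e., per-level buffers followed by the level size, then the final pair of
// iteration-space bounds.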
std::unique_ptr<SparseIterator>
SparseIterationSpace::extractIterator(OpBuilder &b, Location l) const {
  return makeSimpleIterator(b, l, *this);
}
std::unique_ptr<SparseTensorLevel>
sparse_tensor::makeSparseTensorLevel(LevelType lt, Value sz, ValueRange b,
                                     unsigned t, Level l) {
  switch (lt.getLvlFmt()) {
  case LevelFormat::Dense:
    return std::make_unique<DenseLevel>(t, l, sz);
  case LevelFormat::Batch:
    return std::make_unique<BatchLevel>(t, l, sz);
  case LevelFormat::Compressed:
    return std::make_unique<CompressedLevel>(t, l, lt, sz, b[0], b[1]);
  case LevelFormat::LooseCompressed:
    return std::make_unique<LooseCompressedLevel>(t, l, lt, sz, b[0], b[1]);
  case LevelFormat::Singleton:
    return std::make_unique<SingletonLevel>(t, l, lt, sz, b[0]);
  case LevelFormat::NOutOfM:
    return std::make_unique<NOutOfMLevel>(t, l, lt, sz, b[0]);
  case LevelFormat::Undef:
    llvm_unreachable("undefined level format");
  }
  llvm_unreachable("unrecognizable level format");
}
std::unique_ptr<SparseTensorLevel>
sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
                                     unsigned tid, Level lvl) {
  auto stt = getSparseTensorType(t);

  LevelType lt = stt.getLvlType(lvl);
  Value sz = stt.hasEncoding()
                 ? LvlOp::create(b, l, t, lvl).getResult()
                 : tensor::DimOp::create(b, l, t, lvl).getResult();

  SmallVector<Value> buffers;
  if (lt.isWithPosLT()) {
    Value pos = ToPositionsOp::create(b, l, t, lvl);
    buffers.push_back(pos);
  }
  if (lt.isWithCrdLT()) {
    Value crd = ToCoordinatesOp::create(b, l, t, lvl);
    buffers.push_back(crd);
  }
  return makeSparseTensorLevel(lt, sz, buffers, tid, lvl);
}
std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
                                       SparseEmitStrategy strategy) {
  auto stl = std::make_unique<BatchLevel>(tid, lvl, sz);
  auto it = std::make_unique<TrivialIterator>(*stl);
  it->setSparseEmitStrategy(strategy);
  return std::make_pair(std::move(stl), std::move(it));
}
std::unique_ptr<SparseIterator>
sparse_tensor::makeSimpleIterator(OpBuilder &b, Location l,
                                  const SparseIterationSpace &iterSpace) {
  std::unique_ptr<SparseIterator> ret;
  if (!iterSpace.isUnique()) {
    // We always deduplicate a non-unique level, but it should be optimized
    // out if possible.
    ret = std::make_unique<DedupIterator>(b, l, iterSpace.getLastLvl(),
                                          iterSpace.getBoundLo(),
                                          iterSpace.getBoundHi());
  } else {
    ret = std::make_unique<TrivialIterator>(b, l, iterSpace.getLastLvl(),
                                            iterSpace.getBoundLo(),
                                            iterSpace.getBoundHi());
  }
  ret->setSparseEmitStrategy(SparseEmitStrategy::kFunctional);
  return ret;
}

std::unique_ptr<SparseIterator>
sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
                                  SparseEmitStrategy strategy) {
  std::unique_ptr<SparseIterator> ret;
  if (!isUniqueLT(stl.getLT())) {
    // We always deduplicate a non-unique level, but it should be optimized
    // out if possible.
    ret = std::make_unique<DedupIterator>(stl);
  } else {
    ret = std::make_unique<TrivialIterator>(stl);
  }
  ret->setSparseEmitStrategy(strategy);
  return ret;
}
std::unique_ptr<SparseIterator> sparse_tensor::makeSlicedLevelIterator(
    std::unique_ptr<SparseIterator> &&sit, Value offset, Value stride,
    Value size, SparseEmitStrategy strategy) {
  auto ret =
      std::make_unique<FilterIterator>(std::move(sit), offset, stride, size);
  ret->setSparseEmitStrategy(strategy);
  return ret;
}

std::unique_ptr<SparseIterator> sparse_tensor::makePaddedIterator(
    std::unique_ptr<SparseIterator> &&sit, Value padLow, Value padHigh,
    SparseEmitStrategy strategy) {
  auto ret = std::make_unique<PadIterator>(std::move(sit), padLow, padHigh);
  ret->setSparseEmitStrategy(strategy);
  return ret;
}
static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) {
  auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
  if (filter)
    return &filter->getWrappedIterator();
  return it;
}
std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
    OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
    std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
    SparseEmitStrategy strategy) {
  // Try to unwrap a filter parent to reach the actual iterator.
  parent = tryUnwrapFilter(parent);
  std::unique_ptr<SparseIterator> it =
      std::make_unique<NonEmptySubSectIterator>(b, l, parent,
                                                std::move(delegate), size);

  if (stride != 1) {
    // TODO: Bound checking could safely be skipped on sparse levels, but for
    // dense iteration spaces the bound is needed to infer the loop range.
    it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
                                          C_IDX(stride), /*size=*/loopBound);
  }
  it->setSparseEmitStrategy(strategy);
  return it;
}

std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator(
    OpBuilder &b, Location l, const SparseIterator &subSectIter,
    const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
    Value loopBound, unsigned stride, SparseEmitStrategy strategy) {
  // This must be a subsection iterator or a filtered subsection iterator.
  auto &subSect =
      llvm::cast<NonEmptySubSectIterator>(*tryUnwrapFilter(&subSectIter));

  std::unique_ptr<SparseIterator> it = std::make_unique<SubSectIterator>(
      subSect, *tryUnwrapFilter(&parent), std::move(wrap));

  if (stride != 1) {
    it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
                                          C_IDX(stride), /*size=*/loopBound);
  }
  it->setSparseEmitStrategy(strategy);
  return it;
}
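// Note on the two factories above: makeNonEmptySubSectIterator() builds the
// outer iterator over the set of non-empty subsections, while
// makeTraverseSubSectIterator() builds the inner iterator that walks a single
// subsection; both wrap the result in a FilterIterator when stride != 1 so
// that only on-stride coordinates are visited.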