#define CMPI(p, lhs, rhs)                                                      \
  (b.create<arith::CmpIOp>(l, arith::CmpIPredicate::p, (lhs), (rhs))           \
       .getResult())

#define C_FALSE (constantI1(b, l, false))
#define C_TRUE (constantI1(b, l, true))
#define C_IDX(v) (constantIndex(b, l, (v)))
#define YIELD(vs) (b.create<scf::YieldOp>(l, (vs)))
#define ADDI(lhs, rhs) (b.create<arith::AddIOp>(l, (lhs), (rhs)).getResult())
#define ORI(lhs, rhs) (b.create<arith::OrIOp>(l, (lhs), (rhs)).getResult())
#define ANDI(lhs, rhs) (b.create<arith::AndIOp>(l, (lhs), (rhs)).getResult())
#define SUBI(lhs, rhs) (b.create<arith::SubIOp>(l, (lhs), (rhs)).getResult())
#define MULI(lhs, rhs) (b.create<arith::MulIOp>(l, (lhs), (rhs)).getResult())
#define MINUI(lhs, rhs) (b.create<arith::MinUIOp>(l, (lhs), (rhs)).getResult())
#define REMUI(lhs, rhs) (b.create<arith::RemUIOp>(l, (lhs), (rhs)).getResult())
#define DIVUI(lhs, rhs) (b.create<arith::DivUIOp>(l, (lhs), (rhs)).getResult())
#define SELECT(c, lhs, rhs)                                                    \
  (b.create<arith::SelectOp>(l, (c), (lhs), (rhs)).getResult())
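
// Illustrative sketch (not part of the original file): with an OpBuilder `b`
// and a Location `l` in scope, these macros compose into concise IR-building
// expressions, e.g. advancing a position and clamping it to an upper bound:
//
//   Value nxPos = ADDI(pos, C_IDX(1));            // emits arith.addi
//   Value inBound = CMPI(ult, nxPos, posHi);      // emits arith.cmpi ult
//   Value clamped = SELECT(inBound, nxPos, posHi);
//
// Each macro forwards to OpBuilder::create<...> and yields the op's result
// Value, so the calls nest like ordinary arithmetic expressions.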

template <bool hasPosBuffer>
class SparseLevel : public SparseTensorLevel {
  // The buffer holds {posBuf, crdBuf} when the level format needs a position
  // buffer, and just {crdBuf} otherwise.
  using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
                                     std::array<Value, 1>>;

public:
  SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
              BufferT buffers)
      : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}

  ValueRange getLvlBuffers() const override { return buffers; }

  Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                  Value iv) const override {
    SmallVector<Value> memCrd(batchPrefix);
    memCrd.push_back(iv);
    return genIndexLoad(b, l, getCrdBuf(), memCrd);
  }

protected:
  template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
  Value getPosBuf() const {
    return buffers[0];
  }

  Value getCrdBuf() const {
    if constexpr (hasPosBuffer)
      return buffers[1];
    else
      return buffers[0];
  }

  const BufferT buffers;
};

class DenseLevel : public SparseTensorLevel {
public:
  DenseLevel(unsigned tid, Level lvl, Value lvlSize)
      : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize) {}

  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
    llvm_unreachable("locate random-accessible level instead");
  }

  ValueRange getLvlBuffers() const override { return {}; }

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(parentPos.size() == 1 && "Dense level cannot be non-unique.");
    assert(!inPadZone && "Not implemented");
    Value p = parentPos.front();
    Value posLo = MULI(p, lvlSize);
    return {posLo, lvlSize};
  }
};

class BatchLevel : public SparseTensorLevel {
public:
  BatchLevel(unsigned tid, Level lvl, Value lvlSize)
      : SparseTensorLevel(tid, lvl, LevelFormat::Batch, lvlSize) {}

  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
    llvm_unreachable("locate random-accessible level instead");
  }

  ValueRange getLvlBuffers() const override { return {}; }

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(!inPadZone && "Not implemented");
    assert(parentPos.size() == 1 && "Batch level cannot be non-unique.");
    // No need to linearize the position for non-annotated tensors.
    return {C_IDX(0), lvlSize};
  }
};

class CompressedLevel : public SparseLevel</*hasPosBuffer=*/true> {
public:
  CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                  Value posBuffer, Value crdBuffer)
      : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(parentPos.size() == 1 &&
           "compressed level must be the first non-unique level.");

    auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
      Value p = parentPos.front();
      SmallVector<Value> memCrd(batchPrefix);
      memCrd.push_back(p);
      Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
      memCrd.back() = ADDI(p, C_IDX(1));
      Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
      return {pLo, pHi};
    };

    if (inPadZone == nullptr)
      return loadRange();

    SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
    scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
    // True branch: the parent iterator is in the pad zone, yield a "fake"
    // empty range [0, 0).
    b.setInsertionPointToStart(posRangeIf.thenBlock());
    SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
    b.create<scf::YieldOp>(l, emptyRange);

    // False branch: yield the actual range loaded from the position buffer.
    b.setInsertionPointToStart(posRangeIf.elseBlock());
    auto [pLo, pHi] = loadRange();
    SmallVector<Value, 2> loadedRange{pLo, pHi};
    b.create<scf::YieldOp>(l, loadedRange);

    b.setInsertionPointAfter(posRangeIf);
    ValueRange posRange = posRangeIf.getResults();
    return {posRange.front(), posRange.back()};
  }
};
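
// For reference, a sketch of the IR materialized above when `inPadZone` is
// non-null (value names are illustrative, not the actual generated names):
//
//   %lo, %hi = scf.if %inPadZone -> (index, index) {
//     scf.yield %c0, %c0 : index, index          // fake empty range [0, 0)
//   } else {
//     %pLo = memref.load %pos[%p]                // positions[p]
//     %pHi = memref.load %pos[%p1]               // positions[p + 1]
//     scf.yield %pLo, %pHi : index, index
//   }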

class LooseCompressedLevel : public SparseLevel</*hasPosBuffer=*/true> {
public:
  LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                       Value posBuffer, Value crdBuffer)
      : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(parentPos.size() == 1 &&
           "loose-compressed level must be the first non-unique level.");
    assert(!inPadZone && "Not implemented");
    Value p = parentPos.front();
    // Loose-compressed levels store the range as pos[2*p] and pos[2*p + 1].
    // ...
  }
};

class SingletonLevel : public SparseLevel</*hasPosBuffer=*/false> {
public:
  SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                 Value crdBuffer)
      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(parentPos.size() == 1 || parentPos.size() == 2);
    assert(!inPadZone && "Not implemented");
    Value p = parentPos.front();
    Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;

    // A singleton level holds exactly one coordinate per parent position,
    // unless a segment high is provided by a preceding non-unique level.
    if (segHi == nullptr)
      return {p, ADDI(p, C_IDX(1))};
    return {p, segHi};
  }

  ValuePair collapseRangeBetween(OpBuilder &b, Location l,
                                 ValueRange batchPrefix,
                                 std::pair<Value, Value> parentRange)
      const override {
    // ...
  }
};

class NOutOfMLevel : public SparseLevel</*hasPosBuffer=*/false> {
public:
  NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
               Value crdBuffer)
      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos, Value inPadZone) const override {
    assert(parentPos.size() == 1 && isUnique() &&
           "n:m level cannot be non-unique.");
    assert(!inPadZone && "Not implemented");
    // Each n:m block has exactly n specified elements (see getN(lt)).
    // ...
  }
};

static scf::ValueVector genWhenInBound(
    OpBuilder &b, Location l, SparseIterator &it, ValueRange elseRet,
    llvm::function_ref<scf::ValueVector(OpBuilder &, Location, Value)>
        builder) {
  TypeRange ifRetTypes = elseRet.getTypes();
  auto ifOp = b.create<scf::IfOp>(l, ifRetTypes, it.genNotEnd(b, l), true);
  // Then-branch: deref the iterator and delegate to `builder`;
  // else-branch: simply yield `elseRet`.
  // ...
  return ifOp.getResults();
}
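
// Hypothetical usage sketch (names invented for illustration): yield the
// current coordinate while `it` still has elements, and 0 otherwise.
//
//   scf::ValueVector res = genWhenInBound(
//       b, l, it, /*elseRet=*/{C_IDX(0)},
//       [](OpBuilder &b, Location l, Value crd) -> scf::ValueVector {
//         return {crd};
//       });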

class ConcreteIterator : public SparseIterator {
protected:
  ConcreteIterator(const SparseTensorLevel &stl, IterKind kind,
                   unsigned cursorValCnt)
      : SparseIterator(kind, stl.tid, stl.lvl, cursorValCnt, cursorValsStorage),
        stl(stl), cursorValsStorage(cursorValCnt, nullptr) {
    assert(getCursor().size() == cursorValCnt);
  }
  // ...
  bool isBatchIterator() const override {
    return stl.getLT().isa<LevelFormat::Batch>();
  }
  bool randomAccessible() const override {
    return stl.getLT().hasDenseSemantic();
  }
  // ...
};

class TrivialIterator : public ConcreteIterator {
public:
  TrivialIterator(const SparseTensorLevel &stl)
      : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
  // ...
  std::string getDebugInterfacePrefix() const override {
    return std::string("trivial<") + stl.toString() + ">";
  }

  SmallVector<Value> serialize() const override {
    SmallVector<Value> ret;
    ret.push_back(getItPos());
    if (randomAccessible()) {
      // The loop high is implicit (defined by `upperBound()`) for
      // random-access iterators, but posLo is needed for linearization.
      ret.push_back(posLo);
    } else {
      ret.push_back(posHi);
    }
    return ret;
  }

  void deserialize(ValueRange vs) override {
    assert(vs.size() == 2);
    // ...
    if (randomAccessible())
      posLo = vs.back();
    else
      posHi = vs.back();
  }

  ValuePair genForCond(OpBuilder &b, Location l) override {
    if (randomAccessible())
      return {deref(b, l), upperBound(b, l)};
    return std::make_pair(getItPos(), posHi);
  }

  Value genNotEndImpl(OpBuilder &b, Location l) override {
    return CMPI(ult, getItPos(), posHi);
  }

  Value derefImpl(OpBuilder &b, Location l) override {
    if (randomAccessible()) {
      updateCrd(SUBI(getItPos(), posLo));
    } else {
      updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getItPos()));
    }
    return getCrd();
  }

  ValueRange forwardIf(OpBuilder &b, Location l, Value cond) override {
    Value curPos = getCursor().front();
    Value nxPos = forward(b, l).front();
    seek(SELECT(cond, nxPos, curPos));
    return getCursor();
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    assert(randomAccessible());
    // Directly compute the position for the target coordinate.
    seek(ADDI(crd, posLo));
    updateCrd(crd);
    if (isBatchIterator()) {
      // If this is a batch iterator, also update the batch coordinate.
      assert(batchCrds.size() > lvl);
      batchCrds[lvl] = crd;
    }
  }

  Value getItPos() const { return getCursor().front(); }

  Value posLo, posHi;
};
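
// Conceptually (an illustrative sketch, not code from this file), a client
// drives a TrivialIterator over a compressed level as:
//
//   it.genInit(b, l, parent);            // seek(posLo)
//   while (it.genNotEnd(b, l)) {         // pos < posHi
//     Value crd = it.deref(b, l);        // crd = crdBuf[pos]
//     /* ...loop body... */
//     it.forward(b, l);                  // pos += 1
//   }
//
// except that the "while" is emitted as an scf.while/scf.for over the cursor
// values rather than executed in C++.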

class DedupIterator : public ConcreteIterator {
private:
  Value genSegmentHigh(OpBuilder &b, Location l, Value pos);

public:
  DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
                Value posLo, Value posHi)
      : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2), posHi(posHi) {
    seek({posLo, genSegmentHigh(b, l, posLo)});
  }
  // ...
  std::string getDebugInterfacePrefix() const override {
    return std::string("dedup<") + stl.toString() + ">";
  }

  void genInitImpl(OpBuilder &b, Location l,
                   const SparseIterator *parent) override {
    // ...
    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
    seek({posLo, genSegmentHigh(b, l, posLo)});
  }

  SmallVector<Value> serialize() const override {
    SmallVector<Value> ret;
    ret.append(getCursor().begin(), getCursor().end());
    ret.push_back(posHi);
    return ret;
  }

  void deserialize(ValueRange vs) override {
    assert(vs.size() == 3);
    seek(vs.take_front(getCursor().size()));
    posHi = vs.back();
  }

  Value genNotEndImpl(OpBuilder &b, Location l) override {
    return CMPI(ult, getPos(), posHi);
  }

  Value derefImpl(OpBuilder &b, Location l) override {
    updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getPos()));
    return getCrd();
  }

  ValueRange forwardImpl(OpBuilder &b, Location l) override {
    // The next segment starts at the current segment high.
    Value nxPos = getSegHi();
    seek({nxPos, genSegmentHigh(b, l, nxPos)});
    return getCursor();
  }

  Value getPos() const { return getCursor()[0]; }
  Value getSegHi() const { return getCursor()[1]; }

  Value posHi;
};
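
// Worked example (illustrative): for a non-unique level with coordinates
//   crd[4..8] = [2, 2, 2, 5, 7]
// and posLo = 4, genSegmentHigh(b, l, posLo) emits a scan that runs while the
// coordinate equals crd[4], producing segHi = 7. The cursor is then
// {pos = 4, segHi = 7}: deref yields coordinate 2 exactly once, and forward
// jumps straight to pos = 7 (coordinate 5), skipping the duplicates.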

// A simple wrapper that forwards methods to the wrapped iterator.
class SimpleWrapIterator : public SparseIterator {
public:
  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
                     unsigned extraCursorVal = 0)
      : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}

  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    return wrap->getCursorValTypes(b);
  }
  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
  bool randomAccessible() const override { return wrap->randomAccessible(); };
  bool iteratableByFor() const override { return wrap->iteratableByFor(); };

  ValueRange getCurPosition() const override { return wrap->getCurPosition(); }

  void genInitImpl(OpBuilder &b, Location l,
                   const SparseIterator *parent) override {
    wrap->genInit(b, l, parent);
  }
  Value genNotEndImpl(OpBuilder &b, Location l) override {
    return wrap->genNotEndImpl(b, l);
  }
  ValueRange forwardImpl(OpBuilder &b, Location l) override {
    return wrap->forward(b, l);
  }
  Value upperBound(OpBuilder &b, Location l) const override {
    return wrap->upperBound(b, l);
  }
  Value derefImpl(OpBuilder &b, Location l) override {
    return wrap->derefImpl(b, l);
  }
  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    return wrap->locate(b, l, crd);
  }

  SparseIterator &getWrappedIterator() const { return *wrap; }

protected:
  std::unique_ptr<SparseIterator> wrap;
};

// A filter iterator wrapped from another iterator. The filter iterator
// updates the wrapped iterator *in-place*.
class FilterIterator : public SimpleWrapIterator {
  // Coordinate translation between the filtered space and the wrapped space:
  //   crd = (wrapCrd - offset) / stride;  wrapCrd = crd * stride + offset.
  Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) const {
    return DIVUI(SUBI(wrapCrd, offset), stride);
  }
  Value toWrapCrd(OpBuilder &b, Location l, Value crd) const {
    return ADDI(MULI(crd, stride), offset);
  }

  Value genCrdNotLegitPredicate(OpBuilder &b, Location l, Value wrapCrd);
  Value genShouldFilter(OpBuilder &b, Location l);

public:
  FilterIterator(std::unique_ptr<SparseIterator> &&wrap, Value offset,
                 Value stride, Value size)
      : SimpleWrapIterator(std::move(wrap), IterKind::kFilter), offset(offset),
        stride(stride), size(size) {}
  // ...
  std::string getDebugInterfacePrefix() const override {
    return std::string("filter<") + wrap->getDebugInterfacePrefix() + ">";
  }

  bool iteratableByFor() const override { return randomAccessible(); };
  // ...
  void genInitImpl(OpBuilder &b, Location l,
                   const SparseIterator *parent) override {
    wrap->genInit(b, l, parent);
    if (!randomAccessible()) {
      // Forward to the first legitimate coordinate if needed.
      forwardIf(b, l, genShouldFilter(b, l));
    } else {
      // Else, locate to `offset`, the first coordinate included by the slice.
      wrap->locate(b, l, offset);
    }
  }

  Value derefImpl(OpBuilder &b, Location l) override {
    updateCrd(fromWrapCrd(b, l, wrap->deref(b, l)));
    return getCrd();
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    assert(randomAccessible());
    wrap->locate(b, l, toWrapCrd(b, l, crd));
    updateCrd(crd);
  }

  Value offset, stride, size;
};
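
// Worked example (illustrative): with offset = 1, stride = 3, size = 4, the
// filter exposes wrapped coordinates {1, 4, 7, 10} as {0, 1, 2, 3}:
//   fromWrapCrd(7) = (7 - 1) / 3 = 2,   toWrapCrd(2) = 2 * 3 + 1 = 7.
// An off-stride wrapped coordinate such as 5 is rejected by
// genCrdNotLegitPredicate below, since toWrapCrd(fromWrapCrd(5)) = 4 != 5.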

// A pad iterator wrapped from another iterator. The pad iterator updates
// the wrapped iterator *in-place*.
class PadIterator : public SimpleWrapIterator {
public:
  PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
              Value padHigh)
      : SimpleWrapIterator(std::move(wrap), IterKind::kPad,
                           wrap->randomAccessible() ? 1 : 0),
        padLow(padLow), padHigh(padHigh) {}
  // ...
  std::string getDebugInterfacePrefix() const override {
    return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
  }

  ValuePair genForCond(OpBuilder &b, Location l) override {
    if (randomAccessible())
      return {getCrd(), upperBound(b, l)};
    return wrap->genForCond(b, l);
  }

  // For padded dense iterators, the cursor is extended by one extra value
  // that flags whether the current position is in the pad zone.
  ValueRange getCurPosition() const override { return getCursor(); }

  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    SmallVector<Type> ret = wrap->getCursorValTypes(b);
    if (randomAccessible())
      ret.push_back(b.getI1Type()); // The pad-zone flag.
    return ret;
  }

  // The upper bound after padding becomes `size + padLow + padHigh`.
  Value upperBound(OpBuilder &b, Location l) const override {
    return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
  }

  // The padded coordinate is the wrapped coordinate plus `padLow`.
  Value derefImpl(OpBuilder &b, Location l) override {
    updateCrd(ADDI(wrap->deref(b, l), padLow));
    return getCrd();
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    assert(randomAccessible());
    wrap->locate(b, l, SUBI(crd, padLow));
    // The last cursor value is the pad-zone flag:
    //   inPadZone = crd < padLow || crd >= size + padLow.
    Value inPadLow = CMPI(ult, crd, padLow);
    Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow));
    getMutCursorVals().back() = ORI(inPadLow, inPadHigh);
    updateCrd(crd);
  }

  Value padLow, padHigh;
};
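
// Worked example (illustrative): with padLow = 2 and padHigh = 3, a wrapped
// level of size 10 appears as a padded space of size 2 + 10 + 3 = 15. A
// padded coordinate crd maps back to the wrapped coordinate crd - 2, and the
// extra cursor value (random-access case only) records
//   inPadZone = crd < 2 || crd >= 12,
// i.e., whether the iterator currently sits in the zero-padding region.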

class NonEmptySubSectIterator : public SparseIterator {
public:
  using TraverseBuilder = llvm::function_ref<scf::ValueVector(
      OpBuilder &, Location, const SparseIterator *, ValueRange)>;

  NonEmptySubSectIterator(OpBuilder &b, Location l,
                          const SparseIterator *parent,
                          std::unique_ptr<SparseIterator> &&delegate,
                          Value subSectSz)
      : SparseIterator(IterKind::kNonEmptySubSect, /*cursorValCnt=*/3,
                       subSectMeta, *delegate),
        parent(parent), delegate(std::move(delegate)),
        tupleSz(this->delegate->serialize().size()), subSectSz(subSectSz) {
    auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
    if (p == nullptr) {
      // Extract subsections along the root level.
      maxTupleCnt = C_IDX(1);
    } else if (p->lvl == lvl) {
      // Extract subsections along the same level.
      maxTupleCnt = p->maxTupleCnt;
      assert(false && "Not implemented.");
    } else {
      // Extract subsections along the previous level.
      assert(p->lvl + 1 == lvl);
      maxTupleCnt = MULI(p->maxTupleCnt, p->subSectSz);
    }
    // No extra buffer is needed to find subsections on random-accessible
    // levels.
    if (randomAccessible())
      return;
    subSectPosBuf = allocSubSectPosBuf(b, l);
  }
  // ...
  std::string getDebugInterfacePrefix() const override {
    return std::string("ne_sub<") + delegate->getDebugInterfacePrefix() + ">";
  }

  // The subsection buffer is a 2-D buffer of shape [maxTupleCnt, tupleSz + 1],
  // where the trailing slot per tuple stores the start of the next level's
  // tuple range.
  Value allocSubSectPosBuf(OpBuilder &b, Location l) {
    return b.create<memref::AllocaOp>(
        l,
        MemRefType::get({ShapedType::kDynamic, tupleSz + 1}, b.getIndexType()),
        maxTupleCnt);
  }

  void storeNxLvlStart(OpBuilder &b, Location l, Value tupleId,
                       Value start) const {
    b.create<memref::StoreOp>(l, start, subSectPosBuf,
                              ValueRange{tupleId, C_IDX(tupleSz)});
  }

  Value loadNxLvlStart(OpBuilder &b, Location l, Value tupleId) const {
    return b.create<memref::LoadOp>(l, subSectPosBuf,
                                    ValueRange{tupleId, C_IDX(tupleSz)});
  }

  void storeCursorVals(OpBuilder &b, Location l, Value tupleId,
                       ValueRange itVals) const {
    assert(itVals.size() == tupleSz);
    for (unsigned i = 0; i < tupleSz; i++) {
      b.create<memref::StoreOp>(l, itVals[i], subSectPosBuf,
                                ValueRange{tupleId, C_IDX(i)});
    }
  }

  SmallVector<Value> loadCursorVals(OpBuilder &b, Location l,
                                    Value tupleId) const {
    SmallVector<Value> ret;
    for (unsigned i = 0; i < tupleSz; i++) {
      Value v = b.create<memref::LoadOp>(l, subSectPosBuf,
                                         ValueRange{tupleId, C_IDX(i)});
      ret.push_back(v);
    }
    return ret;
  }

  bool isSubSectRoot() const {
    return !parent || !llvm::isa<NonEmptySubSectIterator>(parent);
  }

  // Traverses all the subsections and calls `builder` on each leaf.
  ValueRange inflateSubSectTree(OpBuilder &b, Location l, ValueRange reduc,
                                TraverseBuilder builder) const;

  bool isBatchIterator() const override { return delegate->isBatchIterator(); }
  bool randomAccessible() const override {
    return delegate->randomAccessible();
  };
  bool iteratableByFor() const override { return randomAccessible(); };
  Value upperBound(OpBuilder &b, Location l) const override {
    auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
    Value parentUB =
        p && p->lvl == lvl ? p->upperBound(b, l) : delegate->upperBound(b, l);
    return ADDI(SUBI(parentUB, subSectSz), C_IDX(1));
  };
  // ...
  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    Value absOff = crd;

    if (isSubSectRoot())
      delegate->locate(b, l, absOff);
    else
      assert(parent->lvl + 1 == lvl);

    seek(ValueRange{absOff, absOff, C_TRUE});
    updateCrd(crd);
  }

  Value toSubSectCrd(OpBuilder &b, Location l, Value wrapCrd) const {
    return SUBI(wrapCrd, getAbsOff());
  }

  Value derefImpl(OpBuilder &b, Location l) override {
    // Use the relative offset to coiterate.
    Value crd;
    auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
    if (p && p->lvl == lvl)
      crd = SUBI(getAbsOff(), p->getAbsOff());
    // ...
    updateCrd(crd);
    return crd;
  };
  // ...
  Value getMinCrd() const { return subSectMeta[0]; }
  Value getAbsOff() const { return subSectMeta[1]; }
  Value getNotEnd() const { return subSectMeta[2]; }

  const SparseIterator *parent;
  std::unique_ptr<SparseIterator> delegate;

  // Number of values needed to serialize the wrapped iterator.
  const unsigned tupleSz;
  // Max number of tuples, and the actual number of tuples.
  Value maxTupleCnt, tupleCnt;
  // The buffer storing the serialized tuples.
  Value subSectPosBuf;

  const Value subSectSz;

  // Minimum coordinate, absolute offset, notEnd flag.
  SmallVector<Value, 3> subSectMeta{nullptr, nullptr, nullptr};
};
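
// To summarize the state above (explanatory note, not original code): the
// iterator carries three cursor values, subSectMeta = {minCrd, absOff,
// notEnd}, where minCrd is the smallest coordinate visible in any cached
// wrapped iterator, absOff is the absolute offset (origin) of the first
// non-empty subsection, and notEnd is an i1 flag that turns false once all
// subsections are consumed. The buffer `subSectPosBuf` caches, per tuple, the
// serialized cursor of the wrapped iterator plus the start tuple id of the
// next level.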

class SubSectIterator;

// A wrapper that helps generating code to traverse a subsection, used by both
// `NonEmptySubSectIterator` and `SubSectIterator`.
struct SubSectIterHelper {
  explicit SubSectIterHelper(const SubSectIterator &iter);
  explicit SubSectIterHelper(const NonEmptySubSectIterator &subSect);

  // Delegate methods.
  void deserializeFromTupleId(OpBuilder &b, Location l, Value tupleId);
  void locate(OpBuilder &b, Location l, Value crd);
  Value genNotEnd(OpBuilder &b, Location l);
  Value deref(OpBuilder &b, Location l);
  ValueRange forward(OpBuilder &b, Location l);

  const NonEmptySubSectIterator &subSect;
  SparseIterator &wrap;
};

class SubSectIterator : public SparseIterator {
public:
  SubSectIterator(const NonEmptySubSectIterator &subSect,
                  const SparseIterator &parent,
                  std::unique_ptr<SparseIterator> &&wrap)
      : SparseIterator(IterKind::kSubSect, *wrap,
                       /*extraCursorCnt=*/wrap->randomAccessible() ? 0 : 1),
        subSect(subSect), wrap(std::move(wrap)), parent(parent),
        helper(*this) {
    assert(subSect.tid == tid && subSect.lvl == lvl);
    // ...
  }
  // ...
  std::string getDebugInterfacePrefix() const override {
    return std::string("subsect<") + wrap->getDebugInterfacePrefix() + ">";
  }

  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    SmallVector<Type> ret = wrap->getCursorValTypes(b);
    if (!randomAccessible())
      ret.push_back(b.getIndexType()); // The extra counter.
    return ret;
  }

  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
  bool randomAccessible() const override { return wrap->randomAccessible(); };
  bool iteratableByFor() const override { return randomAccessible(); };
  Value upperBound(OpBuilder &b, Location l) const override {
    return subSect.subSectSz;
  }

  ValueRange getCurPosition() const override { return wrap->getCurPosition(); };

  Value getNxLvlTupleId(OpBuilder &b, Location l) const {
    if (randomAccessible()) {
      return ADDI(getCrd(), nxLvlTupleStart);
    };
    return ADDI(getCursor().back(), nxLvlTupleStart);
  }

  void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override {
    if (randomAccessible()) {
      if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
        assert(p->lvl + 1 == lvl);
        wrap->genInit(b, l, p);
        // Linearize the dense subsection index.
        nxLvlTupleStart = MULI(subSect.subSectSz, p->getNxLvlTupleId(b, l));
      } else {
        assert(subSect.lvl == lvl && subSect.isSubSectRoot());
        wrap->deserialize(subSect.delegate->serialize());
        nxLvlTupleStart = C_IDX(0);
      }
      return;
    }
    assert(!randomAccessible());
    assert(getCursor().size() == wrap->getCursor().size() + 1);
    // Extra counter that counts the number of actually visited coordinates in
    // the sparse subsection.
    getMutCursorVals().back() = C_IDX(0);
    Value tupleId;
    if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
      assert(p->lvl + 1 == lvl);
      tupleId = p->getNxLvlTupleId(b, l);
    } else {
      assert(subSect.lvl == lvl && subSect.isSubSectRoot());
      tupleId = C_IDX(0);
    }
    nxLvlTupleStart = subSect.loadNxLvlStart(b, l, tupleId);
    helper.deserializeFromTupleId(b, l, tupleId);
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    helper.locate(b, l, crd);
    updateCrd(crd);
  }

  Value genNotEndImpl(OpBuilder &b, Location l) override {
    return helper.genNotEnd(b, l);
  }

  Value derefImpl(OpBuilder &b, Location l) override {
    Value crd = helper.deref(b, l);
    updateCrd(crd);
    return crd;
  };

  ValueRange forwardImpl(OpBuilder &b, Location l) override {
    helper.forward(b, l);
    assert(!randomAccessible());
    assert(getCursor().size() == wrap->getCursor().size() + 1);
    getMutCursorVals().back() = ADDI(getCursor().back(), C_IDX(1));
    return getCursor();
  };

  Value nxLvlTupleStart;

  const NonEmptySubSectIterator &subSect;
  std::unique_ptr<SparseIterator> wrap;
  const SparseIterator &parent;

  SubSectIterHelper helper;
};

void SparseIterator::locate(OpBuilder &b, Location l, Value crd) {
  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
    // Debug interface: emit an opaque "<prefix>.locate" op over the cursor.
    std::string prefix = getDebugInterfacePrefix();
    SmallVector<Value> args(getCursor().begin(), getCursor().end());
    args.push_back(crd);
    // ...
  }
  // ...
}

ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) {
  auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), cond, true);
  // ... then-branch: forward(); else-branch: keep the current cursor ...
  seek(ifOp.getResults());
  return getCursor();
}

Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
  auto whileOp = b.create<scf::WhileOp>(
      l, pos.getType(), pos,
      /*beforeBuilder=*/
      [this, pos](OpBuilder &b, Location l, ValueRange ivs) {
        Value inBound = CMPI(ult, ivs.front(), posHi);
        auto ifInBound = b.create<scf::IfOp>(l, b.getI1Type(), inBound, true);
        {
          OpBuilder::InsertionGuard guard(b);
          // If in bound, load the next coordinate and check whether it
          // duplicates the head of the segment.
          b.setInsertionPointToStart(ifInBound.thenBlock());
          Value headCrd = stl.peekCrdAt(b, l, getBatchCrds(), pos);
          Value tailCrd = stl.peekCrdAt(b, l, getBatchCrds(), ivs.front());
          Value isDup = CMPI(eq, headCrd, tailCrd);
          YIELD(isDup);
          // Else, the position is out of bound; yield false.
          b.setInsertionPointToStart(ifInBound.elseBlock());
          YIELD(constantI1(b, l, false));
        }
        b.create<scf::ConditionOp>(l, ifInBound.getResults()[0], ivs);
      },
      /*afterBuilder=*/
      [](OpBuilder &b, Location l, ValueRange ivs) {
        YIELD(ADDI(ivs.front(), C_IDX(1)));
      });
  // The result is the segment high.
  return whileOp.getResult(0);
}

Value FilterIterator::genCrdNotLegitPredicate(OpBuilder &b, Location l,
                                              Value wrapCrd) {
  Value crd = fromWrapCrd(b, l, wrapCrd);
  // Test whether the coordinate is on stride.
  Value notlegit = CMPI(ne, toWrapCrd(b, l, crd), wrapCrd);
  // Test wrapCrd < offset.
  notlegit = ORI(CMPI(ult, wrapCrd, offset), notlegit);
  // Test crd >= size.
  notlegit = ORI(CMPI(uge, crd, size), notlegit);
  return notlegit;
}

Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
  auto r = genWhenInBound(
      b, l, *wrap, C_FALSE,
      [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
        Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
        return {notLegit};
      });
  return llvm::getSingleElement(r);
}

Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) {
  assert(!wrap->randomAccessible());
  auto r = genWhenInBound(
      b, l, *wrap, C_FALSE,
      [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
        Value crd = fromWrapCrd(b, l, wrapCrd);
        // crd < size.
        return {CMPI(ult, crd, size)};
      });
  return llvm::getSingleElement(r);
}

ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
  assert(!randomAccessible());
  // Generates:
  //   bool isFirst = true;
  //   while !it.end() && (!legit(*it) || isFirst)
  //     wrap++; isFirst = false;
  Value isFirst = constantI1(b, l, true);
  SmallVector<Value> whileArgs(getCursor().begin(), getCursor().end());
  whileArgs.push_back(isFirst);
  auto whileOp = b.create<scf::WhileOp>(
      l, ValueRange(whileArgs).getTypes(), whileArgs,
      /*beforeBuilder=*/
      [this](OpBuilder &b, Location l, ValueRange ivs) {
        ValueRange isFirst = linkNewScope(ivs);
        scf::ValueVector cont = genWhenInBound(
            b, l, *wrap, C_FALSE,
            [this, isFirst](OpBuilder &b, Location l,
                            Value wrapCrd) -> scf::ValueVector {
              // Continue while crd < size && !legit, or on the first round.
              Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
              Value crd = fromWrapCrd(b, l, wrapCrd);
              Value ret = ANDI(CMPI(ult, crd, size), notLegit);
              ret = ORI(ret, llvm::getSingleElement(isFirst));
              return {ret};
            });
        b.create<scf::ConditionOp>(l, cont.front(), ivs);
      },
      /*afterBuilder=*/
      [this](OpBuilder &b, Location l, ValueRange ivs) {
        linkNewScope(ivs);
        wrap->forward(b, l);
        SmallVector<Value> yieldVals(getCursor().begin(), getCursor().end());
        yieldVals.push_back(constantI1(b, l, false));
        YIELD(yieldVals);
      });
  b.setInsertionPointAfter(whileOp);
  linkNewScope(whileOp.getResults());
  return getCursor();
}

SubSectIterHelper::SubSectIterHelper(const NonEmptySubSectIterator &subSect)
    : subSect(subSect), wrap(*subSect.delegate) {}

SubSectIterHelper::SubSectIterHelper(const SubSectIterator &iter)
    : subSect(iter.subSect), wrap(*iter.wrap) {}

void SubSectIterHelper::deserializeFromTupleId(OpBuilder &b, Location l,
                                               Value tupleId) {
  assert(!subSect.randomAccessible());
  wrap.deserialize(subSect.loadCursorVals(b, l, tupleId));
}

void SubSectIterHelper::locate(OpBuilder &b, Location l, Value crd) {
  Value absCrd = ADDI(crd, subSect.getAbsOff());
  wrap.locate(b, l, absCrd);
}

Value SubSectIterHelper::genNotEnd(OpBuilder &b, Location l) {
  assert(!wrap.randomAccessible());
  auto r = genWhenInBound(
      b, l, wrap, C_FALSE,
      [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
        Value crd = SUBI(wrapCrd, subSect.getAbsOff());
        // crd < size.
        return {CMPI(ult, crd, subSect.subSectSz)};
      });
  return llvm::getSingleElement(r);
}

Value SubSectIterHelper::deref(OpBuilder &b, Location l) {
  Value wrapCrd = wrap.deref(b, l);
  Value crd = subSect.toSubSectCrd(b, l, wrapCrd);
  return crd;
}

ValueRange SubSectIterHelper::forward(OpBuilder &b, Location l) {
  return wrap.forward(b, l);
}

ValueRange NonEmptySubSectIterator::inflateSubSectTree(
    OpBuilder &b, Location l, ValueRange reduc, TraverseBuilder builder) const {
  // Set up the helper to traverse a sparse subsection.
  SubSectIterHelper helper(*this);
  if (!randomAccessible()) {
    // The subsection tree has been expanded to this level and cached;
    // traverse all the leaves and expand them to the next level.
    SmallVector<Value> iterArgs;
    iterArgs.push_back(C_IDX(0));
    iterArgs.append(reduc.begin(), reduc.end());
    auto forEachLeaf = b.create<scf::ForOp>(
        l, /*lb=*/C_IDX(0), /*ub=*/tupleCnt, /*step=*/C_IDX(1), iterArgs,
        [&helper, &builder](OpBuilder &b, Location l, Value tupleId,
                            ValueRange iterArgs) {
          // Deserialize the iterator at the cached position (tupleId).
          helper.deserializeFromTupleId(b, l, tupleId);

          Value cnt = iterArgs.front();
          // Record the number of leaf nodes included in the subsection; the
          // number indicates the starting tupleId for the next level.
          helper.subSect.storeNxLvlStart(b, l, tupleId, cnt);

          SmallVector<Value> whileArgs(helper.wrap.getCursor());
          whileArgs.append(iterArgs.begin(), iterArgs.end());
          auto whileOp = b.create<scf::WhileOp>(
              l, ValueRange(whileArgs).getTypes(), whileArgs,
              /*beforeBuilder=*/
              [&helper](OpBuilder &b, Location l, ValueRange ivs) {
                helper.wrap.linkNewScope(ivs);
                b.create<scf::ConditionOp>(l, helper.genNotEnd(b, l), ivs);
              },
              /*afterBuilder=*/
              [&helper, &builder](OpBuilder &b, Location l, ValueRange ivs) {
                ValueRange remIter = helper.wrap.linkNewScope(ivs);
                Value cnt = remIter.front();
                // Visit the leaf (elided): `builder` produces the updated
                // user iteration arguments `userNx`.
                // ...
                SmallVector<Value> nxIter = helper.forward(b, l);
                nxIter.append(userNx.begin(), userNx.end());
                YIELD(nxIter);
              });
          ValueRange res = helper.wrap.linkNewScope(whileOp.getResults());
          YIELD(res);
        });
    return forEachLeaf.getResults().drop_front();
  }

  assert(randomAccessible());
  // Helper lambda that traverses the current dense subsection range.
  auto visitDenseSubSect = [&](OpBuilder &b, Location l,
                               const SparseIterator *parent,
                               ValueRange reduc) -> scf::ValueVector {
    assert(!parent || parent->lvl + 1 == lvl);
    delegate->genInit(b, l, parent);
    auto forOp = b.create<scf::ForOp>(
        l, /*lb=*/C_IDX(0), /*ub=*/subSectSz, /*step=*/C_IDX(1), reduc,
        [&](OpBuilder &b, Location l, Value crd, ValueRange iterArgs) {
          helper.locate(b, l, crd);
          // ...
        });
    return llvm::to_vector(forOp.getResults());
  };

  if (isSubSectRoot()) {
    return visitDenseSubSect(b, l, parent, reduc);
  }
  // Else, this is not the root; recurse until the root is reached.
  auto *p = llvm::cast<NonEmptySubSectIterator>(parent);
  assert(p->lvl + 1 == lvl);
  return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
}

void TrivialIterator::genInitImpl(OpBuilder &b, Location l,
                                  const SparseIterator *parent) {
  if (isBatchIterator() && batchCrds.size() <= stl.lvl)
    batchCrds.resize(stl.lvl + 1, nullptr);

  Value c0 = C_IDX(0);
  ValueRange pPos = c0;
  Value inPadZone = nullptr;
  // If the parent iterator is a batch iterator, we also start from 0 (but on
  // a different batch).
  if (parent && !parent->isBatchIterator()) {
    pPos = parent->getCurPosition();
    if (llvm::isa<PadIterator>(parent) && parent->randomAccessible()) {
      // A padded dense iterator creates a "sparse" padded zone that needs to
      // be handled specially.
      inPadZone = pPos.back();
      pPos = pPos.drop_back();
    }
  }

  ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
  std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos, inPadZone);
  // Seek to the lowest position.
  seek(posLo);
}

void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
                                          const SparseIterator *) {
  Value c0 = C_IDX(0);
  if (!isSubSectRoot()) {
    assert(parent->lvl + 1 == lvl);
    if (randomAccessible()) {
      // We cannot initialize the wrapped iterator here because the parent of
      // the current iterator is still unresolved.
      seek({c0, c0, C_TRUE});
      return;
    }

    auto *p = cast<NonEmptySubSectIterator>(parent);
    SmallVector<Value, 2> reduc = {C_IDX(-1), c0}; // {minCrd, tupleId}

    // Expand the subsection tree from the parent level to the current level.
    ValueRange result = p->inflateSubSectTree(
        b, l, reduc,
        [this](OpBuilder &b, Location l, const SparseIterator *parent,
               ValueRange reduc) -> scf::ValueVector {
          assert(parent->lvl + 1 == lvl && reduc.size() == 2);
          Value minCrd = reduc.front();
          Value tupleId = reduc.back();

          // Initialize the subsection range.
          SubSectIterHelper helper(*this);
          helper.wrap.genInit(b, l, parent);

          // Update minCrd while the wrapped iterator is in bound.
          minCrd = genWhenInBound(b, l, helper.wrap, minCrd,
                                  [minCrd](OpBuilder &b, Location l,
                                           Value crd) -> scf::ValueVector {
                                    Value min = MINUI(crd, minCrd);
                                    return {min};
                                  })
                       .front();

          // Cache the sparse range.
          storeCursorVals(b, l, tupleId, helper.wrap.serialize());
          tupleId = ADDI(tupleId, C_IDX(1));
          return {minCrd, tupleId};
        });
    assert(result.size() == 2);
    tupleCnt = result.back();

    Value minCrd = result.front();
    Value absOff = offsetFromMinCrd(b, l, minCrd, subSectSz);
    Value notEnd = CMPI(ne, minCrd, C_IDX(-1));
    seek({minCrd, absOff, notEnd});
    return;
  }

  // The root level of the subsection: it is resolved to a single node.
  assert(isSubSectRoot());
  // The position marks the *lower bound* of the subsection range; the upper
  // bound is determined by the subsection size.
  delegate->genInit(b, l, parent);
  if (randomAccessible()) {
    seek({c0, c0, C_TRUE});
    return;
  }

  // Only one root node: cache the sparse range.
  tupleCnt = C_IDX(1);
  storeCursorVals(b, l, c0, delegate->serialize());
  SmallVector<Value, 3> elseRet{c0, c0, /*notEnd=*/C_FALSE};
  auto meta = genWhenInBound(
      b, l, *delegate, elseRet,
      [this](OpBuilder &b, Location l, Value crd) -> scf::ValueVector {
        Value offset = offsetFromMinCrd(b, l, crd, subSectSz);
        return {crd, offset, C_TRUE};
      });
  seek(meta);
}

ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
  assert(!randomAccessible());
  Value c0 = C_IDX(0), c1 = C_IDX(1);
  // Forward to the next non-empty subsection by generating:
  //
  //   if (minCrd > offset) {
  //     offset += 1;
  //   } else {
  //     minCrd = nextMinInSlice();
  //     offset = minCrd - size + 1;
  //   }
  //   if (offset + size > parent.size)
  //     isNonEmpty = false;
  Value fastPathP = CMPI(ugt, getMinCrd(), getAbsOff());
  auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), fastPathP, true);
  {
    OpBuilder::InsertionGuard guard(b);
    // Fast path: simply forward the offset by one.
    b.setInsertionPointToStart(ifOp.thenBlock());
    Value nxOffset = ADDI(getAbsOff(), c1);
    YIELD((ValueRange{getMinCrd(), nxOffset, getNotEnd()}));

    // Slow path: find the next minimum coordinate among all cached tuples.
    b.setInsertionPointToStart(ifOp.elseBlock());
    SmallVector<Value, 2> loopArgs{C_IDX(-1), // nextMinCrd
                                   C_FALSE};  // isNotEnd
    auto loopNest = scf::buildLoopNest(
        b, l, c0, tupleCnt, c1, loopArgs,
        [this](OpBuilder &b, Location l, ValueRange ivs,
               ValueRange iterArgs) -> scf::ValueVector {
          Value tupleId = ivs.front();
          SubSectIterHelper helper(*this);
          helper.deserializeFromTupleId(b, l, tupleId);

          return genWhenInBound(
              b, l, *delegate, /*elseRet=*/iterArgs,
              [this, iterArgs, tupleId](OpBuilder &b, Location l,
                                        Value crd) -> scf::ValueVector {
                // If crd == minCrd, forward the delegate iterator and persist
                // its new cursor.
                Value isMin = CMPI(eq, crd, getMinCrd());
                delegate->forwardIf(b, l, isMin);
                auto ifIsMin = b.create<scf::IfOp>(l, isMin, false);
                b.setInsertionPointToStart(ifIsMin.thenBlock());
                storeCursorVals(b, l, tupleId, delegate->serialize());
                b.setInsertionPointAfter(ifIsMin);
                // If the delegate is not exhausted, update the next minimum.
                Value nxMin = iterArgs[0];
                return genWhenInBound(b, l, *delegate, /*elseRet=*/iterArgs,
                                      [nxMin](OpBuilder &b, Location l,
                                              Value crd) -> scf::ValueVector {
                                        return {MINUI(crd, nxMin), C_TRUE};
                                      });
              });
        });

    scf::ForOp forOp = loopNest.loops.front();
    b.setInsertionPointAfter(forOp);

    Value nxMinCrd = forOp.getResult(0);
    Value nxNotEnd = forOp.getResult(1);
    Value nxAbsOff = offsetFromMinCrd(b, l, nxMinCrd, subSectSz);
    YIELD((ValueRange{nxMinCrd, nxAbsOff, nxNotEnd}));
  }

  Value nxMinCrd = ifOp.getResult(0);
  Value nxAbsOff = ifOp.getResult(1);
  Value nxNotEnd = ifOp.getResult(2);

  // We should at least forward the offset by one.
  Value minAbsOff = ADDI(getAbsOff(), c1);
  nxAbsOff = b.create<arith::MaxUIOp>(l, minAbsOff, nxAbsOff);

  seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
  // The coordinate should not exceed the space's upper bound.
  Value crd = deref(b, l);
  nxNotEnd = ANDI(nxNotEnd, CMPI(ult, crd, upperBound(b, l)));

  seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
  return getCursor();
}

SparseIterationSpace::SparseIterationSpace(
    Location l, OpBuilder &b, Value t, unsigned tid,
    std::pair<Level, Level> lvlRange, ValueRange parentPos) {
  auto [lvlLo, lvlHi] = lvlRange;

  Value c0 = C_IDX(0);
  if (parentPos.empty())
    parentPos = c0;

  for (Level lvl = lvlLo; lvl < lvlHi; lvl++)
    lvls.emplace_back(makeSparseTensorLevel(b, l, t, tid, lvl));

  bound = lvls.front()->peekRangeAt(b, l, {}, parentPos);
  for (auto &lvl : getLvlRef().drop_front())
    bound = lvl->collapseRangeBetween(b, l, {}, bound);
}

SparseIterationSpace SparseIterationSpace::fromValues(IterSpaceType dstTp,
                                                      ValueRange values,
                                                      unsigned tid) {
  // Reconstruct every sparse tensor level.
  SparseIterationSpace space;
  for (auto [i, lt] : llvm::enumerate(dstTp.getLvlTypes())) {
    unsigned bufferCnt = 0;
    if (lt.isWithPosLT())
      bufferCnt++;
    if (lt.isWithCrdLT())
      bufferCnt++;
    // Sparse tensor buffers.
    ValueRange buffers = values.take_front(bufferCnt);
    values = values.drop_front(bufferCnt);

    // Level size.
    Value sz = values.front();
    values = values.drop_front();
    space.lvls.push_back(
        makeSparseTensorLevel(lt, sz, buffers, tid, i + dstTp.getLoLvl()));
  }
  // Two bounds.
  space.bound = std::make_pair(values[0], values[1]);
  values = values.drop_front(2);

  // Must have consumed all values.
  assert(values.empty());
  return space;
}

std::unique_ptr<SparseIterator>
SparseIterationSpace::extractIterator(OpBuilder &b, Location l) const {
  return makeSimpleIterator(b, l, *this);
}

std::unique_ptr<SparseTensorLevel>
sparse_tensor::makeSparseTensorLevel(LevelType lt, Value sz, ValueRange b,
                                     unsigned t, Level l) {
  switch (lt.getLvlFmt()) {
  case LevelFormat::Dense:
    return std::make_unique<DenseLevel>(t, l, sz);
  case LevelFormat::Batch:
    return std::make_unique<BatchLevel>(t, l, sz);
  case LevelFormat::Compressed:
    return std::make_unique<CompressedLevel>(t, l, lt, sz, b[0], b[1]);
  case LevelFormat::LooseCompressed:
    return std::make_unique<LooseCompressedLevel>(t, l, lt, sz, b[0], b[1]);
  case LevelFormat::Singleton:
    return std::make_unique<SingletonLevel>(t, l, lt, sz, b[0]);
  case LevelFormat::NOutOfM:
    return std::make_unique<NOutOfMLevel>(t, l, lt, sz, b[0]);
  case LevelFormat::Undef:
    llvm_unreachable("undefined level format");
  }
  llvm_unreachable("unrecognizable level format");
}

std::unique_ptr<SparseTensorLevel>
sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
                                     unsigned tid, Level lvl) {
  auto stt = getSparseTensorType(t);
  LevelType lt = stt.getLvlType(lvl);
  Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
                               : b.create<tensor::DimOp>(l, t, lvl).getResult();

  SmallVector<Value, 2> buffers;
  if (lt.isWithPosLT()) {
    Value pos = b.create<ToPositionsOp>(l, t, lvl);
    buffers.push_back(pos);
  }
  if (lt.isWithCrdLT()) {
    Value pos = b.create<ToCoordinatesOp>(l, t, lvl);
    buffers.push_back(pos);
  }
  return makeSparseTensorLevel(lt, sz, buffers, tid, lvl);
}

std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
                                       SparseEmitStrategy strategy) {
  auto stl = std::make_unique<BatchLevel>(tid, lvl, sz);
  auto it = std::make_unique<TrivialIterator>(*stl);
  it->setSparseEmitStrategy(strategy);
  return std::make_pair(std::move(stl), std::move(it));
}
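
// Hypothetical usage sketch: build a synthetic iteration space of size `n`
// (e.g., for a dense output dimension) and a trivial iterator over it.
//
//   auto [stl, it] = makeSynLevelAndIterator(
//       n, /*tid=*/0, /*lvl=*/0, SparseEmitStrategy::kFunctional);
//   it->genInit(b, l, /*parent=*/nullptr);
//
// Since the returned level is a BatchLevel, the iterator is random-accessible
// and can be driven by a plain scf.for from 0 to `n`.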

std::unique_ptr<SparseIterator>
sparse_tensor::makeSimpleIterator(OpBuilder &b, Location l,
                                  const SparseIterationSpace &iterSpace) {
  std::unique_ptr<SparseIterator> ret;
  if (!iterSpace.isUnique()) {
    // We always deduplicate a non-unique level, though it should be
    // optimized out when possible.
    ret = std::make_unique<DedupIterator>(b, l, iterSpace.getLastLvl(),
                                          iterSpace.getBoundLo(),
                                          iterSpace.getBoundHi());
  } else {
    ret = std::make_unique<TrivialIterator>(b, l, iterSpace.getLastLvl(),
                                            iterSpace.getBoundLo(),
                                            iterSpace.getBoundHi());
  }
  ret->setSparseEmitStrategy(SparseEmitStrategy::kFunctional);
  return ret;
}

std::unique_ptr<SparseIterator>
sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
                                  SparseEmitStrategy strategy) {
  std::unique_ptr<SparseIterator> ret;
  if (!isUniqueLT(stl.getLT()))
    ret = std::make_unique<DedupIterator>(stl);
  else
    ret = std::make_unique<TrivialIterator>(stl);
  ret->setSparseEmitStrategy(strategy);
  return ret;
}

std::unique_ptr<SparseIterator>
sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
                                       Value offset, Value stride, Value size,
                                       SparseEmitStrategy strategy) {
  auto ret =
      std::make_unique<FilterIterator>(std::move(sit), offset, stride, size);
  ret->setSparseEmitStrategy(strategy);
  return ret;
}

std::unique_ptr<SparseIterator>
sparse_tensor::makePaddedIterator(std::unique_ptr<SparseIterator> &&sit,
                                  Value padLow, Value padHigh,
                                  SparseEmitStrategy strategy) {
  auto ret = std::make_unique<PadIterator>(std::move(sit), padLow, padHigh);
  ret->setSparseEmitStrategy(strategy);
  return ret;
}

static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) {
  auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
  if (filter)
    return &filter->getWrappedIterator();
  return it;
}

std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
    OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
    std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
    SparseEmitStrategy strategy) {
  // Try to unwrap a filter parent to get the non-empty subsection iterator.
  parent = tryUnwrapFilter(parent);
  std::unique_ptr<SparseIterator> it =
      std::make_unique<NonEmptySubSectIterator>(b, l, parent,
                                                std::move(delegate), size);

  if (stride != 1) {
    // Wrap the iterator with a filter to implement the stride.
    it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
                                          C_IDX(stride), /*size=*/loopBound);
  }
  it->setSparseEmitStrategy(strategy);
  return it;
}

std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator(
    OpBuilder &b, Location l, const SparseIterator &subSectIter,
    const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
    Value loopBound, unsigned stride, SparseEmitStrategy strategy) {
  // This must be a subsection iterator or a filtered subsection iterator.
  auto &subSect =
      llvm::cast<NonEmptySubSectIterator>(*tryUnwrapFilter(&subSectIter));

  std::unique_ptr<SparseIterator> it = std::make_unique<SubSectIterator>(
      subSect, *tryUnwrapFilter(&parent), std::move(wrap));

  if (stride != 1) {
    it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
                                          C_IDX(stride), /*size=*/loopBound);
  }
  it->setSparseEmitStrategy(strategy);
  return it;
}
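
// Putting the factories together (hypothetical sketch): a strided slice of a
// compressed level could be assembled roughly as
//
//   auto lvl = makeSparseTensorLevel(b, l, tensor, /*tid=*/0, /*lvl=*/0);
//   auto it = makeSimpleIterator(*lvl, strategy);
//   it = makeSlicedLevelIterator(std::move(it), offset, stride, size,
//                                strategy);
//
// Each helper wraps the previous iterator; the wrappers translate coordinates
// in deref/locate while delegating the actual traversal to the innermost
// iterator.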