#define CMPI(p, lhs, rhs)                                                      \
  (b.create<arith::CmpIOp>(l, arith::CmpIPredicate::p, (lhs), (rhs))           \
       .getResult())
#define C_FALSE (constantI1(b, l, false))
#define C_TRUE (constantI1(b, l, true))
#define C_IDX(v) (constantIndex(b, l, (v)))
#define YIELD(vs) (b.create<scf::YieldOp>(l, (vs)))
#define ADDI(lhs, rhs) (b.create<arith::AddIOp>(l, (lhs), (rhs)).getResult())
#define ORI(lhs, rhs) (b.create<arith::OrIOp>(l, (lhs), (rhs)).getResult())
#define ANDI(lhs, rhs) (b.create<arith::AndIOp>(l, (lhs), (rhs)).getResult())
#define SUBI(lhs, rhs) (b.create<arith::SubIOp>(l, (lhs), (rhs)).getResult())
#define MULI(lhs, rhs) (b.create<arith::MulIOp>(l, (lhs), (rhs)).getResult())
#define MINUI(lhs, rhs) (b.create<arith::MinUIOp>(l, (lhs), (rhs)).getResult())
#define REMUI(lhs, rhs) (b.create<arith::RemUIOp>(l, (lhs), (rhs)).getResult())
#define DIVUI(lhs, rhs) (b.create<arith::DivUIOp>(l, (lhs), (rhs)).getResult())
#define SELECT(c, lhs, rhs)                                                    \
  (b.create<arith::SelectOp>(l, (c), (lhs), (rhs)).getResult())
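
// Illustrative sketch (not part of the original file): with an OpBuilder `b`
// and Location `l` in scope, the macros above compose into concise
// IR-building expressions, e.g. clamping an index `i` into [lo, hi):
//   Value ok = ANDI(CMPI(uge, i, lo), CMPI(ult, i, hi));
//   Value clamped = SELECT(ok, i, lo);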
template <bool hasPosBuffer>
class SparseLevel : public SparseTensorLevel {
  // The buffer array holds either {crdBuffer} or {posBuffer, crdBuffer},
  // depending on whether the level requires a position array.
  using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
                                     std::array<Value, 1>>;

public:
  SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
              BufferT buffers)
      : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}

  ValueRange getLvlBuffers() const override { return buffers; }

  Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                  Value iv) const override {
    SmallVector<Value> memCrd(batchPrefix);
    memCrd.push_back(iv);
    return genIndexLoad(b, l, getCrdBuf(), memCrd);
  }

protected:
  template <typename T = void,
            typename = std::enable_if_t<hasPosBuffer, T>>
  Value getPosBuf() const {
    return buffers[0];
  }

  Value getCrdBuf() const {
    if constexpr (hasPosBuffer)
      return buffers[1];
    else
      return buffers[0];
  }

  const BufferT buffers;
};
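
// Note (sketch): SparseLevel<true> instantiates BufferT as
// std::array<Value, 2> holding {posBuffer, crdBuffer} (e.g. CompressedLevel
// below), while SparseLevel<false> holds only {crdBuffer} (e.g.
// SingletonLevel), which is why getCrdBuf() must dispatch on hasPosBuffer.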
class DenseLevel : public SparseTensorLevel {
public:
  DenseLevel(unsigned tid, Level lvl, Value lvlSize)
      : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize) {}

  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
    llvm_unreachable("locate random-accessible level instead");
  }

  ValueRange getLvlBuffers() const override { return {}; }

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos) const override {
    assert(parentPos.size() == 1 && "Dense level cannot be non-unique.");
    Value p = parentPos.front();
    Value posLo = MULI(p, lvlSize);
    return {posLo, lvlSize};
  }
};
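
// Worked example (illustrative): for a dense level of size 4 under parent
// position p = 2, peekRangeAt yields posLo = 2 * 4 = 8, i.e. the positions
// belonging to that parent occupy [8, 8 + 4).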
class BatchLevel : public SparseTensorLevel {
public:
  BatchLevel(unsigned tid, Level lvl, Value lvlSize)
      : SparseTensorLevel(tid, lvl, LevelFormat::Batch, lvlSize) {}

  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
    llvm_unreachable("locate random-accessible level instead");
  }

  ValueRange getLvlBuffers() const override { return {}; }

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos) const override {
    assert(parentPos.size() == 1 && "Batch level cannot be non-unique.");
    return {C_IDX(0), lvlSize};
  }
};
class CompressedLevel : public SparseLevel<true> {
public:
  CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                  Value posBuffer, Value crdBuffer)
      : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos) const override {
    assert(parentPos.size() == 1 &&
           "compressed level must be the first non-unique level.");
    Value p = parentPos.front();
    // ... (loads the bounds from two consecutive position-buffer entries)
  }
};
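
// For reference, a sketch of the elided lookup (assuming the CSR-style
// scheme this level implements, not the verbatim body): the bounds come from
// positions p and p + 1 of the position buffer:
//   SmallVector<Value> memCrd(batchPrefix);
//   memCrd.push_back(p);
//   Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
//   memCrd.back() = ADDI(p, C_IDX(1));
//   Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
//   return {pLo, pHi};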
class LooseCompressedLevel : public SparseLevel<true> {
public:
  LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                       Value posBuffer, Value crdBuffer)
      : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos) const override {
    assert(parentPos.size() == 1 &&
           "loose-compressed level must be the first non-unique level.");
    Value p = parentPos.front();
    // ... (loads the bounds from the lo/hi pair at positions 2*p and 2*p+1)
  }
};
class SingletonLevel : public SparseLevel<false> {
public:
  SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
                 Value crdBuffer)
      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos) const override {
    assert(parentPos.size() == 1 || parentPos.size() == 2);
    Value p = parentPos.front();
    Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;

    if (segHi == nullptr)
      return {p, ADDI(p, C_IDX(1))};
    // Use the segment-high as the upper bound.
    return {p, segHi};
  }
};
class NOutOfMLevel : public SparseLevel<false> {
public:
  NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
               Value crdBuffer)
      : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}

  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                        ValueRange parentPos) const override {
    assert(parentPos.size() == 1 && isUnique() &&
           "n:m level cannot be non-unique.");
    // ... (the range covers the whole n:m block)
  }
};
  auto ifOp = b.create<scf::IfOp>(l, ifRetTypes, it.genNotEnd(b, l), true);
  // ...
  return ifOp.getResults();
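
// Illustrative shape of the IR produced by genWhenInBound (a sketch, assuming
// a single reduction value):
//   %res = scf.if %notEnd -> (index) {
//     %crd = <deref the iterator>
//     scf.yield <builder(%crd)>
//   } else {
//     scf.yield %elseRet
//   }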
  ConcreteIterator(const SparseTensorLevel &stl, IterKind kind,
                   unsigned cursorValCnt)
      : SparseIterator(kind, stl.tid, stl.lvl, cursorValCnt, cursorValsStorage),
        stl(stl), cursorValsStorage(cursorValCnt, nullptr) {
    assert(getCursor().size() == cursorValCnt);
  }
  bool isBatchIterator() const override {
    return stl.getLT().isa<LevelFormat::Batch>();
  }
  bool randomAccessible() const override {
    return stl.getLT().hasDenseSemantic();
  }
class TrivialIterator : public ConcreteIterator {
public:
  std::string getDebugInterfacePrefix() const override {
    return std::string("trivial<") + stl.toString() + ">";
  }

  SmallVector<Value> serialize() const override {
    SmallVector<Value> ret;
    ret.push_back(getItPos());
    if (randomAccessible()) {
      // The upper bound is implicit for a random-access iterator, but posLo
      // is still needed for linearization.
      ret.push_back(posLo);
    } else {
      ret.push_back(posHi);
    }
    return ret;
  }

  void deserialize(ValueRange vs) override {
    assert(vs.size() == 2);
    seek(vs.front());
    if (randomAccessible())
      posLo = vs.back();
    else
      posHi = vs.back();
  }

  void genInitImpl(OpBuilder &b, Location l,
                   const SparseIterator *parent) override {
    if (isBatchIterator() && batchCrds.size() <= stl.lvl)
      batchCrds.resize(stl.lvl + 1, nullptr);
    // ... (resolve the parent position pPos and the batch prefix)
    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
    seek(posLo);
  }

  ValuePair genForCond(OpBuilder &b, Location l) {
    if (randomAccessible())
      return {deref(b, l), upperBound(b, l)};
    return std::make_pair(getItPos(), posHi);
  }

  Value genNotEndImpl(OpBuilder &b, Location l) override {
    return CMPI(ult, getItPos(), posHi);
  }

  Value derefImpl(OpBuilder &b, Location l) override {
    if (randomAccessible()) {
      updateCrd(SUBI(getItPos(), posLo));
    } else {
      updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getItPos()));
    }
    return getCrd();
  }

  ValueRange forwardIf(OpBuilder &b, Location l, Value cond) override {
    Value curPos = getCursor().front();
    Value nxPos = forward(b, l).front();
    seek(SELECT(cond, nxPos, curPos));
    return getCursor();
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    assert(randomAccessible());
    seek(ADDI(crd, posLo));
    updateCrd(crd);
    if (isBatchIterator()) {
      // If this is a batch iterator, also update the batch coordinate.
      assert(batchCrds.size() > lvl);
      batchCrds[lvl] = crd;
    }
  }

  Value getItPos() const { return getCursor().front(); }
  Value posLo, posHi;
};
class DedupIterator : public ConcreteIterator {
public:
  Value genSegmentHigh(OpBuilder &b, Location l, Value pos);

  std::string getDebugInterfacePrefix() const override {
    return std::string("dedup<") + stl.toString() + ">";
  }

  void genInitImpl(OpBuilder &b, Location l,
                   const SparseIterator *parent) override {
    // ... (resolve the parent position pPos and the batch prefix)
    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
    seek({posLo, genSegmentHigh(b, l, posLo)});
  }

  SmallVector<Value> serialize() const override {
    SmallVector<Value> ret;
    ret.append(getCursor().begin(), getCursor().end());
    ret.push_back(posHi);
    return ret;
  }
  void deserialize(ValueRange vs) override {
    assert(vs.size() == 3);
    seek(vs.take_front(getCursor().size()));
    posHi = vs.back();
  }

  Value genNotEndImpl(OpBuilder &b, Location l) override {
    return CMPI(ult, getPos(), posHi);
  }

  Value derefImpl(OpBuilder &b, Location l) override {
    updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getPos()));
    return getCrd();
  }

  ValueRange forwardImpl(OpBuilder &b, Location l) override {
    Value nxPos = getSegHi(); // Skip the entire duplication segment.
    seek({nxPos, genSegmentHigh(b, l, nxPos)});
    return getCursor();
  }

  Value getPos() const { return getCursor()[0]; }
  Value getSegHi() const { return getCursor()[1]; }
  Value posHi;
};
  Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) const {
    // crd = (wrapCrd - offset) / stride
    return DIVUI(SUBI(wrapCrd, offset), stride);
  }
  Value toWrapCrd(OpBuilder &b, Location l, Value crd) const {
    // wrapCrd = crd * stride + offset
    return ADDI(MULI(crd, stride), offset);
  }
class FilterIterator : public SparseIterator {
  // ...
public:
  FilterIterator(std::unique_ptr<SparseIterator> &&wrap, Value offset,
                 Value stride, Value size)
      : SparseIterator(IterKind::kFilter, *wrap), offset(offset),
        stride(stride), size(size), wrap(std::move(wrap)) {}

  std::string getDebugInterfacePrefix() const override {
    return std::string("filter<") + wrap->getDebugInterfacePrefix() + ">";
  }
  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    return wrap->getCursorValTypes(b);
  }

  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
  bool randomAccessible() const override { return wrap->randomAccessible(); }
  bool iteratableByFor() const override { return randomAccessible(); }

  ValueRange getCurPosition() const override { return wrap->getCurPosition(); }

  void genInitImpl(OpBuilder &b, Location l,
                   const SparseIterator *parent) override {
    wrap->genInit(b, l, parent);
    if (!randomAccessible()) {
      // Forward to the first legitimate coordinate.
      forwardIf(b, l, genShouldFilter(b, l));
    } else {
      // Otherwise, locate directly to the offset, the first coordinate
      // included by the slice.
      wrap->locate(b, l, offset);
    }
  }

  Value derefImpl(OpBuilder &b, Location l) override {
    updateCrd(fromWrapCrd(b, l, wrap->deref(b, l)));
    return getCrd();
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    assert(randomAccessible());
    wrap->locate(b, l, toWrapCrd(b, l, crd));
    updateCrd(crd);
  }

  Value offset, stride, size;
  std::unique_ptr<SparseIterator> wrap;
};
class NonEmptySubSectIterator : public SparseIterator {
public:
  using TraverseBuilder = llvm::function_ref<scf::ValueVector(
      OpBuilder &, Location, const SparseIterator *, ValueRange)>;

  NonEmptySubSectIterator(OpBuilder &b, Location l,
                          const SparseIterator *parent,
                          std::unique_ptr<SparseIterator> &&delegate,
                          Value subSectSz)
      : /*...*/ parent(parent), delegate(std::move(delegate)),
        tupleSz(this->delegate->serialize().size()), subSectSz(subSectSz) {
    auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
    if (p == nullptr) {
      // Extract subsections along the root level.
      maxTupleCnt = C_IDX(1);
    } else if (p->lvl == lvl) {
      // Extract subsections along the same level.
      maxTupleCnt = p->maxTupleCnt;
      assert(false && "Not implemented.");
    } else {
      // Extract subsections along the previous level.
      assert(p->lvl + 1 == lvl);
      maxTupleCnt = MULI(p->maxTupleCnt, p->subSectSz);
    }
    // No extra buffer is needed to find subsections on random-accessible
    // levels.
    if (randomAccessible())
      return;
    subSectPosBuf = allocSubSectPosBuf(b, l);
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("ne_sub<") + delegate->getDebugInterfacePrefix() + ">";
  }

  // The subsection position buffer is organized as:
  //   [[itVal0, itVal1, ..., nxLvlStart], ...]
  Value allocSubSectPosBuf(OpBuilder &b, Location l) {
    return b.create<memref::AllocaOp>(
        l,
        MemRefType::get({ShapedType::kDynamic, tupleSz + 1}, b.getIndexType()),
        maxTupleCnt);
  }

  void storeNxLvlStart(OpBuilder &b, Location l, Value tupleId,
                       Value start) const {
    b.create<memref::StoreOp>(l, start, subSectPosBuf,
                              ValueRange{tupleId, C_IDX(tupleSz)});
  }

  Value loadNxLvlStart(OpBuilder &b, Location l, Value tupleId) const {
    return b.create<memref::LoadOp>(l, subSectPosBuf,
                                    ValueRange{tupleId, C_IDX(tupleSz)});
  }

  void storeCursorVals(OpBuilder &b, Location l, Value tupleId,
                       ValueRange itVals) const {
    assert(itVals.size() == tupleSz);
    for (unsigned i = 0; i < tupleSz; i++) {
      b.create<memref::StoreOp>(l, itVals[i], subSectPosBuf,
                                ValueRange{tupleId, C_IDX(i)});
    }
  }

  SmallVector<Value> loadCursorVals(OpBuilder &b, Location l,
                                    Value tupleId) const {
    SmallVector<Value> ret;
    for (unsigned i = 0; i < tupleSz; i++) {
      Value v = b.create<memref::LoadOp>(l, subSectPosBuf,
                                         ValueRange{tupleId, C_IDX(i)});
      ret.push_back(v);
    }
    return ret;
  }

  bool isSubSectRoot() const {
    return !parent || !llvm::isa<NonEmptySubSectIterator>(parent);
  }

  ValueRange inflateSubSectTree(OpBuilder &b, Location l, ValueRange reduc,
                                TraverseBuilder builder) const;

  bool isBatchIterator() const override { return delegate->isBatchIterator(); }
  bool randomAccessible() const override {
    return delegate->randomAccessible();
  }
  bool iteratableByFor() const override { return randomAccessible(); }
  Value upperBound(OpBuilder &b, Location l) const override {
    auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
    Value parentUB =
        p && p->lvl == lvl ? p->upperBound(b, l) : delegate->upperBound(b, l);
    return ADDI(SUBI(parentUB, subSectSz), C_IDX(1));
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    Value absOff = crd;
    if (isSubSectRoot())
      delegate->locate(b, l, absOff);
    else
      assert(parent->lvl + 1 == lvl);
    // ...
  }

  Value toSubSectCrd(OpBuilder &b, Location l, Value wrapCrd) const {
    return SUBI(wrapCrd, getAbsOff());
  }

  Value derefImpl(OpBuilder &b, Location l) override {
    // Use the relative offset to co-iterate with other subsection iterators.
    Value crd;
    auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
    if (p && p->lvl == lvl)
      crd = SUBI(getAbsOff(), p->getAbsOff());
    else
      crd = getAbsOff();
    updateCrd(crd);
    return crd;
  }

  Value getMinCrd() const { return subSectMeta[0]; }
  Value getAbsOff() const { return subSectMeta[1]; }
  Value getNotEnd() const { return subSectMeta[2]; }

  const SparseIterator *parent;
  std::unique_ptr<SparseIterator> delegate;

  // Number of values needed to serialize the wrapped iterator.
  const unsigned tupleSz;
  // Max number of tuples, and the actual number of tuples.
  Value maxTupleCnt, tupleCnt;
  // The buffer caching the tuples serialized from the wrapped iterator.
  Value subSectPosBuf;

  const Value subSectSz;

  // Cursor values: (minCrd, absOff, notEnd).
  SmallVector<Value, 3> subSectMeta{nullptr, nullptr, nullptr};
};
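
// Worked example (illustrative): for a non-empty subsection tree whose root
// level extracts a single subsection (maxTupleCnt = 1) with subSectSz = 4,
// the next level reserves maxTupleCnt = 1 * 4 = 4 tuples; each tuple caches
// the serialized cursor of the wrapped iterator plus one extra slot for the
// next-level start, hence the `tupleSz + 1` columns in subSectPosBuf.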
class SubSectIterator;

// A wrapper that helps generating code to traverse a subsection, used by both
// `NonEmptySubSectIterator` and `SubSectIterator`.
struct SubSectIterHelper {
  explicit SubSectIterHelper(const SubSectIterator &iter);
  explicit SubSectIterHelper(const NonEmptySubSectIterator &subSect);

  void deserializeFromTupleId(OpBuilder &b, Location l, Value tupleId);
  void locate(OpBuilder &b, Location l, Value crd);
  Value genNotEnd(OpBuilder &b, Location l);
  Value deref(OpBuilder &b, Location l);
  ValueRange forward(OpBuilder &b, Location l);

  const NonEmptySubSectIterator &subSect;
  SparseIterator &wrap;
};
class SubSectIterator : public SparseIterator {
public:
  SubSectIterator(const NonEmptySubSectIterator &subSect,
                  const SparseIterator &parent,
                  std::unique_ptr<SparseIterator> &&wrap)
      : SparseIterator(IterKind::kSubSect, *wrap,
                       /*extraCursorCnt=*/wrap->randomAccessible() ? 0 : 1),
        subSect(subSect), wrap(std::move(wrap)), parent(parent),
        helper(*this) {
    assert(subSect.tid == tid && subSect.lvl == lvl);
  }

  std::string getDebugInterfacePrefix() const override {
    return std::string("subsect<") + wrap->getDebugInterfacePrefix() + ">";
  }
  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
    SmallVector<Type> ret = wrap->getCursorValTypes(b);
    if (!randomAccessible())
      ret.push_back(b.getIndexType()); // The extra counter.
    return ret;
  }

  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
  bool randomAccessible() const override { return wrap->randomAccessible(); }
  bool iteratableByFor() const override { return randomAccessible(); }
  Value upperBound(OpBuilder &b, Location l) const override {
    return subSect.subSectSz;
  }
  ValueRange getCurPosition() const override { return wrap->getCurPosition(); }

  Value getNxLvlTupleId(OpBuilder &b, Location l) const {
    if (randomAccessible()) {
      return ADDI(getCrd(), nxLvlTupleStart);
    }
    return ADDI(getCursor().back(), nxLvlTupleStart);
  }

  void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override {
    if (randomAccessible()) {
      if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
        assert(p->lvl + 1 == lvl);
        wrap->genInit(b, l, p);
        // Linearize the dense subsection index.
        nxLvlTupleStart = MULI(subSect.subSectSz, p->getNxLvlTupleId(b, l));
      } else {
        assert(subSect.lvl == lvl && subSect.isSubSectRoot());
        wrap->deserialize(subSect.delegate->serialize());
        nxLvlTupleStart = C_IDX(0);
      }
      return;
    }
    assert(!randomAccessible());
    assert(getCursor().size() == wrap->getCursor().size() + 1);
    // The extra counter that counts the number of actually visited
    // coordinates in the sparse subsection.
    getMutCursorVals().back() = C_IDX(0);
    Value tupleId;
    if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
      assert(p->lvl + 1 == lvl);
      tupleId = p->getNxLvlTupleId(b, l);
    } else {
      assert(subSect.lvl == lvl && subSect.isSubSectRoot());
      tupleId = C_IDX(0);
    }
    nxLvlTupleStart = subSect.loadNxLvlStart(b, l, tupleId);
    helper.deserializeFromTupleId(b, l, tupleId);
  }

  void locateImpl(OpBuilder &b, Location l, Value crd) override {
    helper.locate(b, l, crd);
    updateCrd(crd);
  }

  Value genNotEndImpl(OpBuilder &b, Location l) override {
    return helper.genNotEnd(b, l);
  }

  Value derefImpl(OpBuilder &b, Location l) override {
    Value crd = helper.deref(b, l);
    updateCrd(crd);
    return crd;
  }

  ValueRange forwardImpl(OpBuilder &b, Location l) override {
    helper.forward(b, l);
    assert(!randomAccessible());
    assert(getCursor().size() == wrap->getCursor().size() + 1);
    getMutCursorVals().back() = ADDI(getCursor().back(), C_IDX(1));
    return getCursor();
  }

  Value nxLvlTupleStart;

  const NonEmptySubSectIterator &subSect;
  std::unique_ptr<SparseIterator> wrap;
  const SparseIterator &parent;

  SubSectIterHelper helper;
};
  seek(ifOp.getResults());
  // ...
}

Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
  auto whileOp = b.create<scf::WhileOp>(
      l, pos.getType(), pos,
      /*beforeBuilder=*/
      [this, pos](OpBuilder &b, Location l, ValueRange ivs) {
        Value inBound = CMPI(ult, ivs.front(), posHi);
        auto ifInBound = b.create<scf::IfOp>(l, b.getI1Type(), inBound, true);
        {
          OpBuilder::InsertionGuard guard(b);
          // If in bound, load the next coordinate and compare it against the
          // current one to detect a duplication.
          b.setInsertionPointToStart(ifInBound.thenBlock());
          Value headCrd = stl.peekCrdAt(b, l, getBatchCrds(), pos);
          Value tailCrd = stl.peekCrdAt(b, l, getBatchCrds(), ivs.front());
          Value isDup = CMPI(eq, headCrd, tailCrd);
          YIELD(isDup);
          // Else, out of bound, yield false.
          b.setInsertionPointToStart(ifInBound.elseBlock());
          YIELD(C_FALSE);
        }
        b.create<scf::ConditionOp>(l, ifInBound.getResults()[0], ivs);
      },
      /*afterBuilder=*/
      [](OpBuilder &b, Location l, ValueRange ivs) {
        // Forward one position within the duplication segment.
        YIELD(ADDI(ivs.front(), C_IDX(1)));
      });
  return whileOp.getResult(0);
}
Value FilterIterator::genCrdNotLegitPredicate(OpBuilder &b, Location l,
                                              Value wrapCrd) {
  Value crd = fromWrapCrd(b, l, wrapCrd);
  // Test whether the coordinate is on stride.
  Value notlegit = CMPI(ne, toWrapCrd(b, l, crd), wrapCrd);
  // Test wrapCrd < offset.
  notlegit = ORI(CMPI(ult, wrapCrd, offset), notlegit);
  // Test crd >= size.
  notlegit = ORI(CMPI(uge, crd, size), notlegit);
  return notlegit;
}

Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
  auto r = genWhenInBound(
      b, l, *wrap, C_FALSE,
      [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
        Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
        return {notLegit};
      });
  assert(r.size() == 1);
  return r.front();
}

Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) {
  assert(!wrap->randomAccessible());
  auto r = genWhenInBound(
      b, l, *wrap, C_FALSE,
      [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
        Value crd = fromWrapCrd(b, l, wrapCrd);
        // crd < size
        return {CMPI(ult, crd, size)};
      });
  assert(r.size() == 1);
  return r.front();
}

ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
  assert(!randomAccessible());
  // Keeps forwarding the wrapped iterator (at least once) until a legitimate
  // coordinate is found or the end is reached:
  //   bool isFirst = true;
  //   while (!it.end() && (!legit(*it) || isFirst)) { it++; isFirst = false; }
  Value isFirst = constantI1(b, l, true);

  SmallVector<Value> whileArgs(getCursor().begin(), getCursor().end());
  whileArgs.push_back(isFirst);
  auto whileOp = b.create<scf::WhileOp>(
      l, ValueRange(whileArgs).getTypes(), whileArgs,
      /*beforeBuilder=*/
      [this](OpBuilder &b, Location l, ValueRange ivs) {
        ValueRange isFirst = linkNewScope(ivs);
        assert(isFirst.size() == 1);
        scf::ValueVector cont = genWhenInBound(
            b, l, *wrap, C_FALSE,
            [this, isFirst](OpBuilder &b, Location l,
                            Value wrapCrd) -> scf::ValueVector {
              // Continue if `!legit(*it) || isFirst` (while within bound).
              Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
              Value crd = fromWrapCrd(b, l, wrapCrd);
              Value ret = ANDI(CMPI(ult, crd, size), notLegit);
              ret = ORI(ret, isFirst.front());
              return {ret};
            });
        b.create<scf::ConditionOp>(l, cont.front(), ivs);
      },
      /*afterBuilder=*/
      [this](OpBuilder &b, Location l, ValueRange ivs) {
        linkNewScope(ivs);
        wrap->forward(b, l);
        SmallVector<Value> yieldVals(getCursor().begin(), getCursor().end());
        yieldVals.push_back(constantI1(b, l, false));
        YIELD(yieldVals);
      });

  b.setInsertionPointAfter(whileOp);
  linkNewScope(whileOp.getResults());
  return getCursor();
}
SubSectIterHelper::SubSectIterHelper(const NonEmptySubSectIterator &subSect)
    : subSect(subSect), wrap(*subSect.delegate) {}

SubSectIterHelper::SubSectIterHelper(const SubSectIterator &iter)
    : subSect(iter.subSect), wrap(*iter.wrap) {}

void SubSectIterHelper::deserializeFromTupleId(OpBuilder &b, Location l,
                                               Value tupleId) {
  assert(!subSect.randomAccessible());
  wrap.deserialize(subSect.loadCursorVals(b, l, tupleId));
}

void SubSectIterHelper::locate(OpBuilder &b, Location l, Value crd) {
  Value absCrd = ADDI(crd, subSect.getAbsOff());
  wrap.locate(b, l, absCrd);
}

Value SubSectIterHelper::genNotEnd(OpBuilder &b, Location l) {
  assert(!wrap.randomAccessible());
  auto r = genWhenInBound(
      b, l, wrap, C_FALSE,
      [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
        Value crd = SUBI(wrapCrd, subSect.getAbsOff());
        // crd < size
        return {CMPI(ult, crd, subSect.subSectSz)};
      });
  assert(r.size() == 1);
  return r.front();
}

Value SubSectIterHelper::deref(OpBuilder &b, Location l) {
  Value wrapCrd = wrap.deref(b, l);
  Value crd = subSect.toSubSectCrd(b, l, wrapCrd);
  return crd;
}

ValueRange SubSectIterHelper::forward(OpBuilder &b, Location l) {
  return wrap.forward(b, l);
}
ValueRange NonEmptySubSectIterator::inflateSubSectTree(
    OpBuilder &b, Location l, ValueRange reduc, TraverseBuilder builder) const {
  // Set up the helper to traverse a sparse subsection.
  SubSectIterHelper helper(*this);
  if (!randomAccessible()) {
    // The subsection tree has been expanded to this level and cached;
    // traverse all the leaves.
    SmallVector<Value> iterArgs;
    iterArgs.push_back(C_IDX(0));
    iterArgs.append(reduc.begin(), reduc.end());
    auto forEachLeaf = b.create<scf::ForOp>(
        l, /*lb=*/C_IDX(0), /*ub=*/tupleCnt, /*step=*/C_IDX(1), iterArgs,
        [&helper, &builder](OpBuilder &b, Location l, Value tupleId,
                            ValueRange iterArgs) {
          // Deserialize the iterator at the cached position (tupleId).
          helper.deserializeFromTupleId(b, l, tupleId);

          Value cnt = iterArgs.front();
          // Record the number of leaves visited so far, which is the starting
          // tupleId for the next level.
          helper.subSect.storeNxLvlStart(b, l, tupleId, cnt);

          SmallVector<Value> whileArgs(helper.wrap.getCursor().begin(),
                                       helper.wrap.getCursor().end());
          whileArgs.append(iterArgs.begin(), iterArgs.end());
          auto whileOp = b.create<scf::WhileOp>(
              l, ValueRange(whileArgs).getTypes(), whileArgs,
              /*beforeBuilder=*/
              [&helper](OpBuilder &b, Location l, ValueRange ivs) {
                helper.wrap.linkNewScope(ivs);
                b.create<scf::ConditionOp>(l, helper.genNotEnd(b, l), ivs);
              },
              /*afterBuilder=*/
              [&helper, &builder](OpBuilder &b, Location l, ValueRange ivs) {
                ValueRange remIter = helper.wrap.linkNewScope(ivs);
                Value cnt = remIter.front();
                ValueRange userIter = remIter.drop_front();
                scf::ValueVector userNx = builder(b, l, &helper.wrap, userIter);

                SmallVector<Value> nxIter = helper.forward(b, l);
                nxIter.push_back(ADDI(cnt, C_IDX(1)));
                nxIter.append(userNx.begin(), userNx.end());
                YIELD(nxIter);
              });
          ValueRange res = helper.wrap.linkNewScope(whileOp.getResults());
          YIELD(res.drop_front());
        });
    return forEachLeaf.getResults().drop_front();
  }

  assert(randomAccessible());
  // Helper lambda that traverses the current dense subsection range.
  auto visitDenseSubSect = [&, this](OpBuilder &b, Location l,
                                     const SparseIterator *parent,
                                     ValueRange reduc) {
    assert(!parent || parent->lvl + 1 == lvl);
    delegate->genInit(b, l, parent);
    auto forOp = b.create<scf::ForOp>(
        l, /*lb=*/C_IDX(0), /*ub=*/subSectSz, /*step=*/C_IDX(1), reduc,
        [&](OpBuilder &b, Location l, Value crd, ValueRange iterArgs) {
          helper.locate(b, l, crd);
          scf::ValueVector nx = builder(b, l, &helper.wrap, iterArgs);
          YIELD(nx);
        });
    return forOp.getResults();
  };

  if (isSubSectRoot()) {
    return visitDenseSubSect(b, l, parent, reduc);
  }
  // Else, recurse until the subsection root.
  auto *p = llvm::cast<NonEmptySubSectIterator>(parent);
  assert(p->lvl + 1 == lvl);
  return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
}
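
// Illustrative shape of the IR emitted for the sparse path above (sketch):
//   %r:n = scf.for %tupleId = %c0 to %tupleCnt step %c1
//          iter_args(%cnt = %c0, ...) {
//     <deserialize cached cursor at %tupleId>
//     %w:m = scf.while (...) { <genNotEnd> } do { <builder + forward> }
//     scf.yield <updated count and reductions>
//   }
// The dense path instead emits a plain scf.for over [0, subSectSz) and
// locates each coordinate directly.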
void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
                                          const SparseIterator *) {
  Value c0 = C_IDX(0);
  if (!isSubSectRoot()) {
    assert(parent->lvl + 1 == lvl);
    if (randomAccessible()) {
      // We cannot call wrap->genInit() here to initialize the wrapped
      // iterator, because the parent of the current iterator is still
      // unresolved.
      seek({/*minCrd=*/c0, /*absOff=*/c0, /*notEnd=*/C_TRUE});
      return;
    }

    auto *p = cast<NonEmptySubSectIterator>(parent);
    SmallVector<Value, 2> reduc = {C_IDX(-1), /*tupleId=*/c0};

    // Expand the subsection tree from the parent level to this level.
    ValueRange result = p->inflateSubSectTree(
        b, l, reduc,
        [this](OpBuilder &b, Location l, const SparseIterator *parent,
               ValueRange reduc) -> scf::ValueVector {
          assert(parent->lvl + 1 == lvl && reduc.size() == 2);
          Value minCrd = reduc.front();
          Value tupleId = reduc.back();

          // Initialize the subsection range.
          SubSectIterHelper helper(*this);
          helper.wrap.genInit(b, l, parent);

          // Update the minimum coordinate.
          minCrd = genWhenInBound(b, l, helper.wrap, minCrd,
                                  [minCrd](OpBuilder &b, Location l,
                                           Value crd) -> scf::ValueVector {
                                    Value min = MINUI(crd, minCrd);
                                    return {min};
                                  })
                       .front();

          // Cache the sparse range.
          storeCursorVals(b, l, tupleId, helper.wrap.serialize());
          tupleId = ADDI(tupleId, C_IDX(1));
          return {minCrd, tupleId};
        });
    assert(result.size() == 2);
    tupleCnt = result.back();

    Value minCrd = result.front();
    Value absOff = offsetFromMinCrd(b, l, minCrd, subSectSz);
    Value notEnd = CMPI(ne, minCrd, C_IDX(-1));
    seek({minCrd, absOff, notEnd});
    return;
  }
  // The root subsection is resolved to a single node.
  assert(isSubSectRoot());

  // The position marks the *lower bound* of the subsection range; the upper
  // bound is determined by subSectSz.
  delegate->genInit(b, l, parent);
  if (randomAccessible()) {
    seek({/*minCrd=*/c0, /*absOff=*/c0, /*notEnd=*/C_TRUE});
    return;
  }

  // There is only one root tuple.
  tupleCnt = C_IDX(1);
  // Cache the sparse range.
  storeCursorVals(b, l, c0, delegate->serialize());
  SmallVector<Value, 3> elseRet{c0, c0, /*notEnd=*/C_FALSE};
  auto meta = genWhenInBound(
      b, l, *delegate, elseRet,
      [this](OpBuilder &b, Location l, Value crd) -> scf::ValueVector {
        Value offset = offsetFromMinCrd(b, l, crd, subSectSz);
        return {crd, offset, C_TRUE};
      });
  seek(meta);
}
ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
  assert(!randomAccessible());
  Value c0 = C_IDX(0), c1 = C_IDX(1);
  // Forward to the next non-empty subsection:
  //   if (minCrd > absOff) absOff += 1;                       // fast path
  //   else { minCrd = nextMinCrd(); absOff = minCrd - size + 1; }
  Value fastPathP = CMPI(ugt, getMinCrd(), getAbsOff());
  auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), fastPathP, true);
  {
    OpBuilder::InsertionGuard guard(b);
    // Fast path: only the offset moves.
    b.setInsertionPointToStart(ifOp.thenBlock());
    Value nxOffset = ADDI(getAbsOff(), c1);
    YIELD((ValueRange{getMinCrd(), nxOffset, getNotEnd()}));

    // Slow path: recompute the minimum coordinate over all cached tuples.
    b.setInsertionPointToStart(ifOp.elseBlock());
    SmallVector<Value, 2> loopArgs{C_IDX(-1), // nextMinCrd
                                   C_FALSE};  // isNotEnd
    auto loopNest = scf::buildLoopNest(
        b, l, c0, tupleCnt, c1, loopArgs,
        [this](OpBuilder &b, Location l, ValueRange ivs,
               ValueRange iterArgs) -> scf::ValueVector {
          Value tupleId = ivs.front();
          SubSectIterHelper helper(*this);
          helper.deserializeFromTupleId(b, l, tupleId);

          return genWhenInBound(
              b, l, *delegate, /*elseRet=*/iterArgs,
              [this, iterArgs, tupleId](OpBuilder &b, Location l,
                                        Value crd) -> scf::ValueVector {
                // Forward the iterator if it is currently at the minimum.
                Value isMin = CMPI(eq, crd, getMinCrd());
                delegate->forwardIf(b, l, isMin);
                // Store the forwarded cursor when needed.
                auto ifIsMin = b.create<scf::IfOp>(l, isMin, false);
                b.setInsertionPointToStart(ifIsMin.thenBlock());
                storeCursorVals(b, l, tupleId, delegate->serialize());
                b.setInsertionPointAfter(ifIsMin);
                // If not ended, yield (min(nxMin, crd), true).
                Value nxMin = iterArgs[0];
                return genWhenInBound(b, l, *delegate, /*elseRet=*/iterArgs,
                                      [nxMin](OpBuilder &b, Location l,
                                              Value crd) -> scf::ValueVector {
                                        return {MINUI(crd, nxMin), C_TRUE};
                                      });
              });
        });

    scf::ForOp forOp = loopNest.loops.front();
    b.setInsertionPointAfter(forOp);

    Value nxMinCrd = forOp.getResult(0);
    Value nxNotEnd = forOp.getResult(1);
    Value nxAbsOff = offsetFromMinCrd(b, l, nxMinCrd, subSectSz);
    YIELD((ValueRange{nxMinCrd, nxAbsOff, nxNotEnd}));
  }

  Value nxMinCrd = ifOp.getResult(0);
  Value nxAbsOff = ifOp.getResult(1);
  Value nxNotEnd = ifOp.getResult(2);

  // The offset must advance by at least one.
  Value minAbsOff = ADDI(getAbsOff(), c1);
  nxAbsOff = b.create<arith::MaxUIOp>(l, minAbsOff, nxAbsOff);

  seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
  // The coordinate must not exceed the space's upper bound.
  Value crd = deref(b, l);
  nxNotEnd = ANDI(nxNotEnd, CMPI(ult, crd, upperBound(b, l)));

  seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
  return getCursor();
}
std::unique_ptr<SparseTensorLevel>
sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
                                     unsigned tid, Level lvl) {
  auto stt = getSparseTensorType(t);

  LevelType lt = stt.getLvlType(lvl);
  Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
                               : b.create<tensor::DimOp>(l, t, lvl).getResult();

  switch (lt.getLvlFmt()) {
  case LevelFormat::Dense:
    return std::make_unique<DenseLevel>(tid, lvl, sz);
  case LevelFormat::Batch:
    return std::make_unique<BatchLevel>(tid, lvl, sz);
  case LevelFormat::Compressed: {
    Value pos = b.create<ToPositionsOp>(l, t, lvl);
    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
    return std::make_unique<CompressedLevel>(tid, lvl, lt, sz, pos, crd);
  }
  case LevelFormat::LooseCompressed: {
    Value pos = b.create<ToPositionsOp>(l, t, lvl);
    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
    return std::make_unique<LooseCompressedLevel>(tid, lvl, lt, sz, pos, crd);
  }
  case LevelFormat::Singleton: {
    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
    return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
  }
  case LevelFormat::NOutOfM: {
    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
    return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
  }
  case LevelFormat::Undef:
    llvm_unreachable("undefined level format");
  }
  llvm_unreachable("unrecognizable level format");
}
std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
                                       SparseEmitStrategy strategy) {
  auto stl = std::make_unique<BatchLevel>(tid, lvl, sz);
  auto it = std::make_unique<TrivialIterator>(*stl);
  it->setSparseEmitStrategy(strategy);
  return std::make_pair(std::move(stl), std::move(it));
}
std::unique_ptr<SparseIterator>
sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
                                  SparseEmitStrategy strategy) {
  std::unique_ptr<SparseIterator> ret;
  if (!isUniqueLT(stl.getLT())) {
    // We always dedup a non-unique level, though it should be optimized out
    // when possible.
    ret = std::make_unique<DedupIterator>(stl);
  } else {
    ret = std::make_unique<TrivialIterator>(stl);
  }
  ret->setSparseEmitStrategy(strategy);
  return ret;
}
std::unique_ptr<SparseIterator>
sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
                                       Value offset, Value stride, Value size,
                                       SparseEmitStrategy strategy) {
  auto ret =
      std::make_unique<FilterIterator>(std::move(sit), offset, stride, size);
  ret->setSparseEmitStrategy(strategy);
  return ret;
}
static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) {
  auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
  if (filter)
    return filter->wrap.get();
  return it;
}
std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
    OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
    std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
    SparseEmitStrategy strategy) {
  // Unwrap a filter parent (if any) to get the actual subsection parent.
  parent = tryUnwrapFilter(parent);
  std::unique_ptr<SparseIterator> it =
      std::make_unique<NonEmptySubSectIterator>(b, l, parent,
                                                std::move(delegate), size);

  if (stride != 1) {
    it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
                                          C_IDX(stride), /*size=*/loopBound);
  }
  it->setSparseEmitStrategy(strategy);
  return it;
}
std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator(
    OpBuilder &b, Location l, const SparseIterator &subSectIter,
    const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
    Value loopBound, unsigned stride, SparseEmitStrategy strategy) {
  // The subsection iterator must be a NonEmptySubSectIterator, possibly
  // wrapped inside a filter.
  auto &subSect =
      llvm::cast<NonEmptySubSectIterator>(*tryUnwrapFilter(&subSectIter));
  std::unique_ptr<SparseIterator> it = std::make_unique<SubSectIterator>(
      subSect, *tryUnwrapFilter(&parent), std::move(wrap));

  if (stride != 1) {
    it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
                                          C_IDX(stride), /*size=*/loopBound);
  }
  it->setSparseEmitStrategy(strategy);
  return it;
}
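
// Usage sketch (illustrative): a sliced read of a level with offset %o,
// stride 2, and size %sz composes the factories as
//   auto it = makeSimpleIterator(*stl, strategy);
//   it = makeSlicedLevelIterator(std::move(it), /*offset=*/o,
//                                /*stride=*/s, /*size=*/sz, strategy);
// so the outer FilterIterator translates coordinates while the wrapped
// iterator walks the underlying storage.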