SparseTensorIterator.cpp
1//===- SparseTensorIterator.cpp -------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "SparseTensorIterator.h"
10#include "CodegenUtils.h"
11
12#include "mlir/Dialect/MemRef/IR/MemRef.h"
13#include "mlir/Dialect/SCF/IR/SCF.h"
14#include "mlir/Dialect/Tensor/IR/Tensor.h"
15
16using namespace mlir;
17using namespace mlir::sparse_tensor;
18using ValuePair = std::pair<Value, Value>;
19using ValueTuple = std::tuple<Value, Value, Value>;
20
21//===----------------------------------------------------------------------===//
22// File local helper functions/macros.
23//===----------------------------------------------------------------------===//
24#define CMPI(p, lhs, rhs) \
25 (arith::CmpIOp::create(b, l, arith::CmpIPredicate::p, (lhs), (rhs)) \
26 .getResult())
27
28#define C_FALSE (constantI1(b, l, false))
29#define C_TRUE (constantI1(b, l, true))
30#define C_IDX(v) (constantIndex(b, l, (v)))
31#define YIELD(vs) (scf::YieldOp::create(b, l, (vs)))
32#define ADDI(lhs, rhs) (arith::AddIOp::create(b, l, (lhs), (rhs)).getResult())
33#define ORI(lhs, rhs) (arith::OrIOp::create(b, l, (lhs), (rhs)).getResult())
34#define ANDI(lhs, rhs) (arith::AndIOp::create(b, l, (lhs), (rhs)).getResult())
35#define SUBI(lhs, rhs) (arith::SubIOp::create(b, l, (lhs), (rhs)).getResult())
36#define MULI(lhs, rhs) (arith::MulIOp::create(b, l, (lhs), (rhs)).getResult())
37#define MINUI(lhs, rhs) (arith::MinUIOp::create(b, l, (lhs), (rhs)).getResult())
38#define REMUI(lhs, rhs) (arith::RemUIOp::create(b, l, (lhs), (rhs)).getResult())
39#define DIVUI(lhs, rhs) (arith::DivUIOp::create(b, l, (lhs), (rhs)).getResult())
40#define SELECT(c, lhs, rhs) \
41 (arith::SelectOp::create(b, l, (c), (lhs), (rhs)).getResult())
42
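// Example usage of the helper macros above (illustrative only): the expression
// SELECT(CMPI(uge, i, n), SUBI(i, n), C_IDX(0)) emits, through the ambient
// builder `b` at location `l`, an arith.cmpi, an arith.subi, and an
// arith.select that together compute `i >= n ? i - n : 0`.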
43//===----------------------------------------------------------------------===//
44// SparseTensorLevel derived classes.
45//===----------------------------------------------------------------------===//
46
47namespace {
48
49template <bool hasPosBuffer>
50class SparseLevel : public SparseTensorLevel {
51 // It is either an array of size 2 or size 1 depending on whether the sparse
52 // level requires a position array.
53 using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
54 std::array<Value, 1>>;
55
56public:
57 SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
58 BufferT buffers)
59 : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}
60
61 ValueRange getLvlBuffers() const override { return buffers; }
62
63 Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
64 Value iv) const override {
65 SmallVector<Value> memCrd(batchPrefix);
66 memCrd.push_back(iv);
67 return genIndexLoad(b, l, getCrdBuf(), memCrd);
68 }
69
70protected:
71 template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
72 Value getPosBuf() const {
73 return buffers[0];
74 }
75
76 Value getCrdBuf() const {
77 if constexpr (hasPosBuffer)
78 return buffers[1];
79 else
80 return buffers[0];
81 }
82
83 const BufferT buffers;
84};
85
86class DenseLevel : public SparseTensorLevel {
87public:
88 DenseLevel(unsigned tid, Level lvl, Value lvlSize)
89 : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize) {}
90
91 Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
92 llvm_unreachable("locate random-accessible level instead");
93 }
94
95 ValueRange getLvlBuffers() const override { return {}; }
96
97 ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
98 ValueRange parentPos, Value inPadZone) const override {
99 assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
100 assert(!inPadZone && "Not implemented");
101 Value p = parentPos.front();
102 Value posLo = MULI(p, lvlSize);
103 return {posLo, lvlSize};
104 }
105};
106
107class BatchLevel : public SparseTensorLevel {
108public:
109 BatchLevel(unsigned tid, Level lvl, Value lvlSize)
110 : SparseTensorLevel(tid, lvl, LevelFormat::Batch, lvlSize) {}
111
112 Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
113 llvm_unreachable("locate random-accessible level instead");
114 }
115
116 ValueRange getLvlBuffers() const override { return {}; }
117
118 ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
119 ValueRange parentPos, Value inPadZone) const override {
120 assert(!inPadZone && "Not implemented");
121 assert(parentPos.size() == 1 && "Batch level cannot be non-unique.");
122 // No need to linearize the position for non-annotated tensors.
123 return {C_IDX(0), lvlSize};
124 }
125};
126
127class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
128public:
129 CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
130 Value posBuffer, Value crdBuffer)
131 : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
132
133 ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
134 ValueRange parentPos, Value inPadZone) const override {
135
136 assert(parentPos.size() == 1 &&
137 "compressed level must be the first non-unique level.");
138
139 auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
140 Value p = parentPos.front();
141 SmallVector<Value> memCrd(batchPrefix);
142 memCrd.push_back(p);
143 Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
144 memCrd.back() = ADDI(p, C_IDX(1));
145 Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
146 return {pLo, pHi};
147 };
148
149 if (inPadZone == nullptr)
150 return loadRange();
151
152 SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
153 scf::IfOp posRangeIf = scf::IfOp::create(b, l, types, inPadZone, true);
154 // True branch, returns a "fake" empty range [0, 0) if parent
155 // iterator is in pad zone.
156 b.setInsertionPointToStart(posRangeIf.thenBlock());
157
158 SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
159 scf::YieldOp::create(b, l, emptyRange);
160
161 // False branch, returns the actual range.
162 b.setInsertionPointToStart(posRangeIf.elseBlock());
163 auto [pLo, pHi] = loadRange();
164 SmallVector<Value, 2> loadedRange{pLo, pHi};
165 scf::YieldOp::create(b, l, loadedRange);
166
167 b.setInsertionPointAfter(posRangeIf);
168 ValueRange posRange = posRangeIf.getResults();
169 return {posRange.front(), posRange.back()};
170 }
171};
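// Illustrative sketch of what the pad-aware path in CompressedLevel::peekRangeAt
// above emits when `inPadZone` is non-null (types and names abbreviated):
//
//   %lo, %hi = scf.if %inPadZone -> (index, index) {
//     scf.yield %c0, %c0          // "fake" empty range [0, 0)
//   } else {
//     %pLo = memref.load %pos[%p]
//     %pHi = memref.load %pos[%p + 1]
//     scf.yield %pLo, %pHi        // actual position range
//   }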
172
173class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
174public:
175 LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
176 Value posBuffer, Value crdBuffer)
177 : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
178
179 ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
180 ValueRange parentPos, Value inPadZone) const override {
181 assert(parentPos.size() == 1 &&
182 "loose-compressed level must be the first non-unique level.");
183 assert(!inPadZone && "Not implemented");
184 SmallVector<Value> memCrd(batchPrefix);
185 Value p = parentPos.front();
186 p = MULI(p, C_IDX(2));
187 memCrd.push_back(p);
188 Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
189 memCrd.back() = ADDI(p, C_IDX(1));
190 Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
191 return {pLo, pHi};
192 }
193};
194
195class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
196public:
197 SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
198 Value crdBuffer)
199 : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
200
201 ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
202 ValueRange parentPos, Value inPadZone) const override {
203 assert(parentPos.size() == 1 || parentPos.size() == 2);
204 assert(!inPadZone && "Not implemented");
205 Value p = parentPos.front();
206 Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
207
208 if (segHi == nullptr)
209 return {p, ADDI(p, C_IDX(1))};
210 // Use the segHi as the loop upper bound.
211 return {p, segHi};
212 }
213
214 std::pair<Value, Value>
215 collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
216 std::pair<Value, Value> parentRange) const override {
217 // Singleton level keeps the same range after collapsing.
218 return parentRange;
219 };
220};
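// For a singleton level, a parent position p yields the unit range [p, p + 1);
// when a segment high is provided (non-unique parent), the range is [p, segHi).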
221
222class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
223public:
224 NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
225 Value crdBuffer)
226 : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
227
228 ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
229 ValueRange parentPos, Value inPadZone) const override {
230 assert(parentPos.size() == 1 && isUnique() &&
231 "n:m level can not be non-unique.");
232 assert(!inPadZone && "Not implemented");
233 // Each n:m blk has exactly n specified elements.
234 auto n = getN(lt);
235 Value posLo = MULI(parentPos.front(), C_IDX(n));
236 return {posLo, ADDI(posLo, C_IDX(n))};
237 }
238};
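// For example, on a 2:4 structured level (n = 2), a parent position p covers
// the positions [2 * p, 2 * p + 2), since each block stores exactly n elements.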
239
240} // namespace
241
242//===----------------------------------------------------------------------===//
243// File local helpers
244//===----------------------------------------------------------------------===//
245
246static scf::ValueVector genWhenInBound(
247 OpBuilder &b, Location l, SparseIterator &it, ValueRange elseRet,
248 llvm::function_ref<scf::ValueVector(OpBuilder &, Location, Value)>
249 builder) {
250 TypeRange ifRetTypes = elseRet.getTypes();
251 auto ifOp = scf::IfOp::create(b, l, ifRetTypes, it.genNotEnd(b, l), true);
252
253 b.setInsertionPointToStart(ifOp.thenBlock());
254 Value crd = it.deref(b, l);
255 scf::ValueVector ret = builder(b, l, crd);
256 YIELD(ret);
257
258 b.setInsertionPointToStart(ifOp.elseBlock());
259 YIELD(elseRet);
260
261 b.setInsertionPointAfter(ifOp);
262 return ifOp.getResults();
263}
264
265/// Generates code to compute the *absolute* offset of the slice based on the
266/// provided minimum coordinate in the slice.
267/// E.g., when reducing d0 + d1 + d2, we need two slices to fully reduce the
268/// expression, i.e., s1 = slice(T, d0), s2 = slice(s1, d1). The *absolute*
269/// offset is the offset computed relative to the initial tensor T.
270///
271/// When isNonEmpty == true, the computed offset is meaningless and should not
272/// be used at runtime; the method currently generates code to return 0 in
273/// that case.
274///
275/// offset = minCrd >= size ? minCrd - size + 1 : 0;
276static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
277 Value size) {
278 Value geSize = CMPI(uge, minCrd, size);
279 // Compute minCrd - size + 1.
280 Value mms = SUBI(ADDI(minCrd, C_IDX(1)), size);
281 // This is the absolute offset related to the actual tensor.
282 return SELECT(geSize, mms, C_IDX(0));
283}
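// E.g., with size = 3: minCrd = 7 gives offset = 7 - 3 + 1 = 5, while
// minCrd = 1 (< size) gives offset = 0.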
284
285//===----------------------------------------------------------------------===//
286// SparseIterator derived classes.
287//===----------------------------------------------------------------------===//
288
289namespace {
290
291// The iterator that traverses a concrete sparse tensor level. High-level
292// abstract iterators wrap it to achieve more complex goals (such as collapsing
293// several levels). It also holds the common storage for the mlir::Values of
294// itself as well as of its wrappers.
295class ConcreteIterator : public SparseIterator {
296protected:
297 ConcreteIterator(const SparseTensorLevel &stl, IterKind kind,
298 unsigned cursorValCnt)
299 : SparseIterator(kind, stl.tid, stl.lvl, cursorValCnt, cursorValsStorage),
300 stl(stl), cursorValsStorage(cursorValCnt, nullptr) {
301 assert(getCursor().size() == cursorValCnt);
302 };
303
304public:
305 // For LLVM-style RTTI.
306 static bool classof(const SparseIterator *from) {
307 return from->kind == IterKind::kTrivial;
308 }
309
310 bool isBatchIterator() const override {
311 return stl.getLT().isa<LevelFormat::Batch>();
312 }
313 bool randomAccessible() const override {
314 return stl.getLT().hasDenseSemantic();
315 };
316 bool iteratableByFor() const override { return kind != IterKind::kDedup; };
317 Value upperBound(OpBuilder &b, Location l) const override {
318 return stl.getSize();
319 };
320
321protected:
322 const SparseTensorLevel &stl;
323 // Owner of the storage; all wrappers built on top of a concrete iterator
324 // share the same storage so that the iterator values are always
325 // synchronized.
326 SmallVector<Value> cursorValsStorage;
327};
328
329class TrivialIterator : public ConcreteIterator {
330public:
331 TrivialIterator(const SparseTensorLevel &stl)
332 : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
333
334 TrivialIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
335 Value posLo, Value posHi)
336 : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1), posLo(posLo),
337 posHi(posHi) {
338 seek(posLo);
339 }
340
341 std::string getDebugInterfacePrefix() const override {
342 return std::string("trivial<") + stl.toString() + ">";
343 }
344 SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
345 return {b.getIndexType()};
346 }
347
348 SmallVector<Value> serialize() const override {
349 SmallVector<Value> ret;
350 ret.push_back(getItPos());
351 if (randomAccessible()) {
352 // The loop upper bound is implicit (defined by `upperBound()`) for a
353 // random-access iterator, but we need to memorize posLo for linearization.
354 ret.push_back(posLo);
355 } else {
356 ret.push_back(posHi);
357 }
358 return ret;
359 };
360
361 void deserialize(ValueRange vs) override {
362 assert(vs.size() == 2);
363 seek(vs.front());
364 if (randomAccessible())
365 posLo = vs.back();
366 else
367 posHi = vs.back();
368 };
369
370 void genInitImpl(OpBuilder &b, Location l,
371 const SparseIterator *parent) override;
372
373 ValuePair genForCond(OpBuilder &b, Location l) override {
374 if (randomAccessible())
375 return {deref(b, l), upperBound(b, l)};
376 return std::make_pair(getItPos(), posHi);
377 }
378
379 Value genNotEndImpl(OpBuilder &b, Location l) override {
380 // We use the first-level bound as the bound for the collapsed set of levels.
381 return CMPI(ult, getItPos(), posHi);
382 }
383
384 Value derefImpl(OpBuilder &b, Location l) override {
385 if (randomAccessible()) {
386 updateCrd(SUBI(getItPos(), posLo));
387 } else {
388 updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getItPos()));
389 }
390 return getCrd();
391 };
392
393 ValueRange forwardImpl(OpBuilder &b, Location l) override {
394 seek(ADDI(getItPos(), C_IDX(1)));
395 return getCursor();
396 }
397
398 ValueRange forwardIf(OpBuilder &b, Location l, Value cond) override {
399 Value curPos = getCursor().front();
400 Value nxPos = forward(b, l).front();
401 seek(SELECT(cond, nxPos, curPos));
402 return getCursor();
403 }
404
405 void locateImpl(OpBuilder &b, Location l, Value crd) override {
406 assert(randomAccessible());
407 // Seek to the linearized position.
408 seek(ADDI(crd, posLo));
409 updateCrd(crd);
410 if (isBatchIterator()) {
411 // If this is a batch iterator, also update the batch coordinate.
412 assert(batchCrds.size() > lvl);
413 batchCrds[lvl] = crd;
414 }
415 }
416
417 Value getItPos() const { return getCursor().front(); }
418 Value posLo, posHi;
419};
420
421class DedupIterator : public ConcreteIterator {
422private:
423 Value genSegmentHigh(OpBuilder &b, Location l, Value pos);
424
425public:
426 DedupIterator(const SparseTensorLevel &stl)
427 : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) {
428 assert(!stl.isUnique());
429 }
430
431 DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
432 Value posLo, Value posHi)
433 : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2), posHi(posHi) {
434 assert(!stl.isUnique());
435 seek({posLo, genSegmentHigh(b, l, posLo)});
436 }
437
438 // For LLVM-style RTTI.
439 static bool classof(const SparseIterator *from) {
440 return from->kind == IterKind::kDedup;
441 }
442
443 std::string getDebugInterfacePrefix() const override {
444 return std::string("dedup<") + stl.toString() + ">";
445 }
446 SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
447 return {b.getIndexType(), b.getIndexType()};
448 }
449
450 void genInitImpl(OpBuilder &b, Location l,
451 const SparseIterator *parent) override {
452 Value c0 = C_IDX(0);
453 ValueRange pPos = c0;
454
455 // If the parent iterator is a batch iterator, we also start from 0 (but
456 // on a different batch).
457 if (parent && !parent->isBatchIterator())
458 pPos = parent->getCurPosition();
459
460 Value posLo;
461 ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
462 std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
463
464 seek({posLo, genSegmentHigh(b, l, posLo)});
465 }
466
467 SmallVector<Value> serialize() const override {
468 SmallVector<Value> ret;
469 ret.append(getCursor().begin(), getCursor().end());
470 ret.push_back(posHi);
471 return ret;
472 };
473 void deserialize(ValueRange vs) override {
474 assert(vs.size() == 3);
475 seek(vs.take_front(getCursor().size()));
476 posHi = vs.back();
477 };
478
479 Value genNotEndImpl(OpBuilder &b, Location l) override {
480 return CMPI(ult, getPos(), posHi);
481 }
482
483 Value derefImpl(OpBuilder &b, Location l) override {
484 updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getPos()));
485 return getCrd();
486 };
487
488 ValueRange forwardImpl(OpBuilder &b, Location l) override {
489 Value nxPos = getSegHi(); // forward the position to the next segment.
490 seek({nxPos, genSegmentHigh(b, l, nxPos)});
491 return getCursor();
492 }
493
494 Value getPos() const { return getCursor()[0]; }
495 Value getSegHi() const { return getCursor()[1]; }
496
497 Value posHi;
498};
499
500// A utility base iterator that delegates all methods to the wrapped iterator.
501class SimpleWrapIterator : public SparseIterator {
502public:
503 SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
504 unsigned extraCursorVal = 0)
505 : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}
506
507 void setSparseEmitStrategy(SparseEmitStrategy strategy) override {
508 wrap->setSparseEmitStrategy(strategy);
509 }
510
511 SparseEmitStrategy getSparseEmitStrategy() const override {
512 return wrap->getSparseEmitStrategy();
513 }
514
515 SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
516 return wrap->getCursorValTypes(b);
517 }
518 bool isBatchIterator() const override { return wrap->isBatchIterator(); }
519 bool randomAccessible() const override { return wrap->randomAccessible(); };
520 bool iteratableByFor() const override { return wrap->iteratableByFor(); };
521
522 SmallVector<Value> serialize() const override { return wrap->serialize(); };
523 void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
524 ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
525 void genInitImpl(OpBuilder &b, Location l,
526 const SparseIterator *parent) override {
527 wrap->genInit(b, l, parent);
528 }
529 Value genNotEndImpl(OpBuilder &b, Location l) override {
530 return wrap->genNotEndImpl(b, l);
531 }
532 ValueRange forwardImpl(OpBuilder &b, Location l) override {
533 return wrap->forward(b, l);
534 };
535 Value upperBound(OpBuilder &b, Location l) const override {
536 return wrap->upperBound(b, l);
537 };
538
539 Value derefImpl(OpBuilder &b, Location l) override {
540 return wrap->derefImpl(b, l);
541 }
542
543 void locateImpl(OpBuilder &b, Location l, Value crd) override {
544 return wrap->locate(b, l, crd);
545 }
546
547 SparseIterator &getWrappedIterator() const { return *wrap; }
548
549protected:
550 std::unique_ptr<SparseIterator> wrap;
551};
552
553//
554// A filter iterator wrapped from another iterator. The filter iterator updates
555// the wrapped iterator *in place*.
556//
557class FilterIterator : public SimpleWrapIterator {
558 // Coordinate translation between the crd loaded from the wrapped iterator
559 // and the filter iterator.
560 Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) const {
561 // crd = (wrapCrd - offset) / stride
562 return DIVUI(SUBI(wrapCrd, offset), stride);
563 }
564 Value toWrapCrd(OpBuilder &b, Location l, Value crd) const {
565 // wrapCrd = crd * stride + offset
566 return ADDI(MULI(crd, stride), offset);
567 }
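  // For example, with offset = 2 and stride = 3, wrapped coordinates 2, 5, 8
  // map to filter coordinates 0, 1, 2, and toWrapCrd inverts that mapping.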
568
569 Value genCrdNotLegitPredicate(OpBuilder &b, Location l, Value wrapCrd);
570
571 Value genShouldFilter(OpBuilder &b, Location l);
572
573public:
574 // TODO: avoid unnecessary checks when offset == 0 and/or when stride == 1
575 // and/or when crd is always < size.
576 FilterIterator(std::unique_ptr<SparseIterator> &&wrap, Value offset,
577 Value stride, Value size)
578 : SimpleWrapIterator(std::move(wrap), IterKind::kFilter), offset(offset),
579 stride(stride), size(size) {}
580
581 // For LLVM-style RTTI.
582 static bool classof(const SparseIterator *from) {
583 return from->kind == IterKind::kFilter;
584 }
585
586 std::string getDebugInterfacePrefix() const override {
587 return std::string("filter<") + wrap->getDebugInterfacePrefix() + ">";
588 }
589
590 bool iteratableByFor() const override { return randomAccessible(); };
591 Value upperBound(OpBuilder &b, Location l) const override { return size; };
592
593 void genInitImpl(OpBuilder &b, Location l,
594 const SparseIterator *parent) override {
595 wrap->genInit(b, l, parent);
596 if (!randomAccessible()) {
597 // TODO: we can skip this when stride == 1 and offset == 0, we can also
598 // use binary search here.
599 forwardIf(b, l, genShouldFilter(b, l));
600 } else {
601 // Else, locate to the slice.offset, which is the first coordinate
602 // included by the slice.
603 wrap->locate(b, l, offset);
604 }
605 }
606
607 Value genNotEndImpl(OpBuilder &b, Location l) override;
608
609 Value derefImpl(OpBuilder &b, Location l) override {
610 updateCrd(fromWrapCrd(b, l, wrap->deref(b, l)));
611 return getCrd();
612 }
613
614 void locateImpl(OpBuilder &b, Location l, Value crd) override {
615 assert(randomAccessible());
616 wrap->locate(b, l, toWrapCrd(b, l, crd));
617 updateCrd(crd);
618 }
619
620 ValueRange forwardImpl(OpBuilder &b, Location l) override;
621
622 Value offset, stride, size;
623};
624
625//
626// A pad iterator wrapped from another iterator. The pad iterator updates
627// the wrapped iterator *in-place*.
628//
629class PadIterator : public SimpleWrapIterator {
630
631public:
632 PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
633 Value padHigh)
634 : SimpleWrapIterator(std::move(wrap), IterKind::kPad,
635 wrap->randomAccessible() ? 1 : 0),
636 padLow(padLow), padHigh(padHigh) {}
637
638 // For LLVM-style RTTI.
639 static bool classof(const SparseIterator *from) {
640 return from->kind == IterKind::kPad;
641 }
642
643 std::string getDebugInterfacePrefix() const override {
644 return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
645 }
646
647 // Returns a pair of values for the *lower* and *upper* bound, respectively.
648 ValuePair genForCond(OpBuilder &b, Location l) override {
649 if (randomAccessible())
650 return {getCrd(), upperBound(b, l)};
651 return wrap->genForCond(b, l);
652 }
653
654 // For a padded dense iterator, we append an `inPadZone: bool` value in
655 // addition to the values used by the wrapped iterator.
656 ValueRange getCurPosition() const override { return getCursor(); }
657
658 SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
659 SmallVector<Type> ret = wrap->getCursorValTypes(b);
660 // Need an extra boolean value `inPadZone` for padded dense iterator.
661 if (randomAccessible())
662 ret.push_back(b.getI1Type());
663
664 return ret;
665 }
666
667 // The upper bound after padding becomes `size + padLow + padHigh`.
668 Value upperBound(OpBuilder &b, Location l) const override {
669 return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
670 };
671
672 // The pad_coord = coord + pad_lo
673 Value derefImpl(OpBuilder &b, Location l) override {
674 updateCrd(ADDI(wrap->deref(b, l), padLow));
675 return getCrd();
676 }
677
678 void locateImpl(OpBuilder &b, Location l, Value crd) override {
679 assert(randomAccessible());
680 wrap->locate(b, l, SUBI(crd, padLow));
681
682 // inPadZone = crd < padLow || crd >= size + padLow.
683 Value inPadLow = CMPI(ult, crd, padLow);
684 Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow));
685 getMutCursorVals().back() = ORI(inPadLow, inPadHigh);
686
687 updateCrd(crd);
688 }
689
690 Value padLow, padHigh;
691};
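// For example, with padLow = 2 and padHigh = 1 on a wrapped level of size 4,
// the padded upper bound is 4 + 2 + 1 = 7, and coordinates 0, 1 and 6 fall in
// the pad zone (crd < padLow or crd >= size + padLow).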
692
693class NonEmptySubSectIterator : public SparseIterator {
694public:
695 using TraverseBuilder = llvm::function_ref<scf::ValueVector(
696 OpBuilder &, Location, const SparseIterator *, ValueRange)>;
697
698 NonEmptySubSectIterator(OpBuilder &b, Location l,
699 const SparseIterator *parent,
700 std::unique_ptr<SparseIterator> &&delegate,
701 Value subSectSz)
702 : SparseIterator(IterKind::kNonEmptySubSect, 3, subSectMeta, *delegate),
703 parent(parent), delegate(std::move(delegate)),
704 tupleSz(this->delegate->serialize().size()), subSectSz(subSectSz) {
705 auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
706 if (p == nullptr) {
707 // Extract subsections along the root level.
708 maxTupleCnt = C_IDX(1);
709 } else if (p->lvl == lvl) {
710 // Extract subsections along the same level.
711 maxTupleCnt = p->maxTupleCnt;
712 assert(false && "Not implemented.");
713 } else {
714 // Extract subsections along the previous level.
715 assert(p->lvl + 1 == lvl);
716 maxTupleCnt = MULI(p->maxTupleCnt, p->subSectSz);
717 }
718 // We don't need an extra buffer to find subsections on random-accessible
719 // levels.
720 if (randomAccessible())
721 return;
722 subSectPosBuf = allocSubSectPosBuf(b, l);
723 }
724
725 // For LLVM-style RTTI.
726 static bool classof(const SparseIterator *from) {
727 return from->kind == IterKind::kNonEmptySubSect;
728 }
729
730 std::string getDebugInterfacePrefix() const override {
731 return std::string("ne_sub<") + delegate->getDebugInterfacePrefix() + ">";
732 }
733 SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
734 // minCrd, absolute offset, notEnd
735 return {b.getIndexType(), b.getIndexType(), b.getI1Type()};
736 }
737
738 // The sliced pointer buffer is organized as:
739 // [[itVal0, itVal1, ..., pNx0],
740 // [itVal0, itVal1, ..., pNx0],
741 // ...]
742 Value allocSubSectPosBuf(OpBuilder &b, Location l) {
743 return memref::AllocaOp::create(
744 b, l,
745 MemRefType::get({ShapedType::kDynamic, tupleSz + 1}, b.getIndexType()),
746 maxTupleCnt);
747 }
748
749 void storeNxLvlStart(OpBuilder &b, Location l, Value tupleId,
750 Value start) const {
751 memref::StoreOp::create(b, l, start, subSectPosBuf,
752 ValueRange{tupleId, C_IDX(tupleSz)});
753 }
754
755 Value loadNxLvlStart(OpBuilder &b, Location l, Value tupleId) const {
756 return memref::LoadOp::create(b, l, subSectPosBuf,
757 ValueRange{tupleId, C_IDX(tupleSz)});
758 }
759
760 void storeCursorVals(OpBuilder &b, Location l, Value tupleId,
761 ValueRange itVals) const {
762 assert(itVals.size() == tupleSz);
763 for (unsigned i = 0; i < tupleSz; i++) {
764 memref::StoreOp::create(b, l, itVals[i], subSectPosBuf,
765 ValueRange{tupleId, C_IDX(i)});
766 }
767 }
768
769 SmallVector<Value> loadCursorVals(OpBuilder &b, Location l,
770 Value tupleId) const {
771 SmallVector<Value> ret;
772 for (unsigned i = 0; i < tupleSz; i++) {
773 Value v = memref::LoadOp::create(b, l, subSectPosBuf,
774 ValueRange{tupleId, C_IDX(i)});
775 ret.push_back(v);
776 }
777 return ret;
778 }
779
780 bool isSubSectRoot() const {
781 return !parent || !llvm::isa<NonEmptySubSectIterator>(parent);
782 }
783
784 // Generates code that inflates the current subsection tree up to the current
785 // level such that every leaf node is visited.
786 ValueRange inflateSubSectTree(OpBuilder &b, Location l, ValueRange reduc,
787 TraverseBuilder builder) const;
788
789 bool isBatchIterator() const override { return delegate->isBatchIterator(); }
790 bool randomAccessible() const override {
791 return delegate->randomAccessible();
792 };
793 bool iteratableByFor() const override { return randomAccessible(); };
794 Value upperBound(OpBuilder &b, Location l) const override {
795 auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
796 Value parentUB =
797 p && p->lvl == lvl ? p->upperBound(b, l) : delegate->upperBound(b, l);
798 return ADDI(SUBI(parentUB, subSectSz), C_IDX(1));
799 };
800
801 void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override;
802
803 void locateImpl(OpBuilder &b, Location l, Value crd) override {
804 Value absOff = crd;
805
806 if (isSubSectRoot())
807 delegate->locate(b, l, absOff);
808 else
809 assert(parent->lvl + 1 == lvl);
810
811 seek(ValueRange{absOff, absOff, C_TRUE});
812 updateCrd(crd);
813 }
814
815 Value toSubSectCrd(OpBuilder &b, Location l, Value wrapCrd) const {
816 return SUBI(wrapCrd, getAbsOff());
817 }
818
819 Value genNotEndImpl(OpBuilder &b, Location l) override {
820 return getNotEnd();
821 };
822
823 Value derefImpl(OpBuilder &b, Location l) override {
824 // Use the relative offset to coiterate.
825 Value crd;
826 auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
827 if (p && p->lvl == lvl)
828 crd = SUBI(getAbsOff(), p->getAbsOff());
829 crd = getAbsOff();
830
831 updateCrd(crd);
832 return crd;
833 };
834
835 ValueRange forwardImpl(OpBuilder &b, Location l) override;
836
837 Value getMinCrd() const { return subSectMeta[0]; }
838 Value getAbsOff() const { return subSectMeta[1]; }
839 Value getNotEnd() const { return subSectMeta[2]; }
840
841 const SparseIterator *parent;
842 std::unique_ptr<SparseIterator> delegate;
843
844 // Number of values required to serialize the wrapped iterator.
845 const unsigned tupleSz;
846 // Max number of tuples, and the actual number of tuples.
847 Value maxTupleCnt, tupleCnt;
848 // The memory used to cache the tuple serialized from the wrapped iterator.
849 Value subSectPosBuf;
850
851 const Value subSectSz;
852
853 // minCrd, absolute offset, notEnd
854 SmallVector<Value, 3> subSectMeta{nullptr, nullptr, nullptr};
855};
856
857class SubSectIterator;
858
859// A wrapper that helps generate code to traverse a subsection, used
860// by both `NonEmptySubSectIterator` and `SubSectIterator`.
861struct SubSectIterHelper {
862 explicit SubSectIterHelper(const SubSectIterator &iter);
863 explicit SubSectIterHelper(const NonEmptySubSectIterator &subSect);
864
865 // Delegate methods.
866 void deserializeFromTupleId(OpBuilder &b, Location l, Value tupleId);
867 void locate(OpBuilder &b, Location l, Value crd);
868 Value genNotEnd(OpBuilder &b, Location l);
869 Value deref(OpBuilder &b, Location l);
870 ValueRange forward(OpBuilder &b, Location l);
871
872 const NonEmptySubSectIterator &subSect;
873 SparseIterator &wrap;
874};
875
876class SubSectIterator : public SparseIterator {
877public:
878 SubSectIterator(const NonEmptySubSectIterator &subSect,
879 const SparseIterator &parent,
880 std::unique_ptr<SparseIterator> &&wrap)
881 : SparseIterator(IterKind::kSubSect, *wrap,
882 /*extraCursorCnt=*/wrap->randomAccessible() ? 0 : 1),
883 subSect(subSect), wrap(std::move(wrap)), parent(parent), helper(*this) {
884 assert(subSect.tid == tid && subSect.lvl == lvl);
885 assert(parent.kind != IterKind::kSubSect || parent.lvl + 1 == lvl);
886 };
887
888 // For LLVM-style RTTI.
889 static bool classof(const SparseIterator *from) {
890 return from->kind == IterKind::kSubSect;
891 }
892
893 std::string getDebugInterfacePrefix() const override {
894 return std::string("subsect<") + wrap->getDebugInterfacePrefix() + ">";
895 }
896 SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
897 SmallVector<Type> ret = wrap->getCursorValTypes(b);
898 if (!randomAccessible())
899 ret.push_back(b.getIndexType()); // The extra counter.
900 return ret;
901 }
902
903 bool isBatchIterator() const override { return wrap->isBatchIterator(); }
904 bool randomAccessible() const override { return wrap->randomAccessible(); };
905 bool iteratableByFor() const override { return randomAccessible(); };
906 Value upperBound(OpBuilder &b, Location l) const override {
907 return subSect.subSectSz;
908 }
909
910 ValueRange getCurPosition() const override { return wrap->getCurPosition(); };
911
912 Value getNxLvlTupleId(OpBuilder &b, Location l) const {
913 if (randomAccessible()) {
914 return ADDI(getCrd(), nxLvlTupleStart);
915 };
916 return ADDI(getCursor().back(), nxLvlTupleStart);
917 }
918
919 void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override {
920 if (randomAccessible()) {
921 if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
922 assert(p->lvl + 1 == lvl);
923 wrap->genInit(b, l, p);
924 // Linearize the dense subsection index.
925 nxLvlTupleStart = MULI(subSect.subSectSz, p->getNxLvlTupleId(b, l));
926 } else {
927 assert(subSect.lvl == lvl && subSect.isSubSectRoot());
928 wrap->deserialize(subSect.delegate->serialize());
929 nxLvlTupleStart = C_IDX(0);
930 }
931 return;
932 }
933 assert(!randomAccessible());
934 assert(getCursor().size() == wrap->getCursor().size() + 1);
935 // Extra counter that counts the number of actually visited coordinates in
936 // the sparse subsection.
937 getMutCursorVals().back() = C_IDX(0);
938 Value tupleId;
939 if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
940 assert(p->lvl + 1 == lvl);
941 tupleId = p->getNxLvlTupleId(b, l);
942 } else {
943 assert(subSect.lvl == lvl && subSect.isSubSectRoot());
944 tupleId = C_IDX(0);
945 }
946 nxLvlTupleStart = subSect.loadNxLvlStart(b, l, tupleId);
947 helper.deserializeFromTupleId(b, l, tupleId);
948 }
949
950 void locateImpl(OpBuilder &b, Location l, Value crd) override {
951 helper.locate(b, l, crd);
952 updateCrd(crd);
953 }
954
955 Value genNotEndImpl(OpBuilder &b, Location l) override {
956 return helper.genNotEnd(b, l);
957 }
958
959 Value derefImpl(OpBuilder &b, Location l) override {
960 Value crd = helper.deref(b, l);
961 updateCrd(crd);
962 return crd;
963 };
964
965 ValueRange forwardImpl(OpBuilder &b, Location l) override {
966 helper.forward(b, l);
967 assert(!randomAccessible());
968 assert(getCursor().size() == wrap->getCursor().size() + 1);
969 getMutCursorVals().back() = ADDI(getCursor().back(), C_IDX(1));
970 return getCursor();
971 };
972
973 Value nxLvlTupleStart;
974
975 const NonEmptySubSectIterator &subSect;
976 std::unique_ptr<SparseIterator> wrap;
977 const SparseIterator &parent;
978
979 SubSectIterHelper helper;
980};
981
982} // namespace
983
984//===----------------------------------------------------------------------===//
985// SparseIterator derived classes implementation.
986//===----------------------------------------------------------------------===//
987
988void SparseIterator::genInit(OpBuilder &b, Location l,
989 const SparseIterator *p) {
990 if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
991 std::string prefix = getDebugInterfacePrefix();
992 Operation *begin = b.create(l, b.getStringAttr(prefix + ".begin"), {},
993 getCursorValTypes(b));
994 seek(begin->getResults());
995 return;
996 }
997 // Inherit batch coordinates from the parent.
998 if (p)
999 inherentBatch(*p);
1000 // TODO: support lowering to function call.
1001 return genInitImpl(b, l, p);
1002}
1003
1004Value SparseIterator::genNotEnd(OpBuilder &b, Location l) {
1005 if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
1006 std::string prefix = getDebugInterfacePrefix();
1007 Operation *notEnd = b.create(l, b.getStringAttr(prefix + ".not_end"),
1008 getCursor(), b.getI1Type());
1009 return notEnd->getResult(0);
1010 }
1011 // TODO: support lowering to function call.
1012 return genNotEndImpl(b, l);
1013}
1014
1015void SparseIterator::locate(OpBuilder &b, Location l, Value crd) {
1016 if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
1017 std::string prefix = getDebugInterfacePrefix();
1018 SmallVector<Value> args = getCursor();
1019 args.push_back(crd);
1020 Operation *locate = b.create(l, b.getStringAttr(prefix + ".locate"), args,
1021 getCursorValTypes(b));
1022 seek(locate->getResults());
1023 updateCrd(crd);
1024 return;
1025 }
1026 return locateImpl(b, l, crd);
1027}
1028
1029Value SparseIterator::deref(OpBuilder &b, Location l) {
1030 if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
1031 std::string prefix = getDebugInterfacePrefix();
1033 Operation *deref = b.create(l, b.getStringAttr(prefix + ".deref"),
1034 getCursor(), b.getIndexType());
1035 updateCrd(deref->getResult(0));
1036 return getCrd();
1037 }
1038 return derefImpl(b, l);
1039}
1040
1041ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
1042 assert(!randomAccessible());
1043 if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
1044 std::string prefix = getDebugInterfacePrefix();
1045 Operation *next = b.create(l, b.getStringAttr(prefix + ".next"),
1046 getCursor(), getCursorValTypes(b));
1047 seek(next->getResults());
1048 return getCursor();
1049 }
1050 return forwardImpl(b, l);
1051}
1052
1053ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) {
1054 auto ifOp = scf::IfOp::create(b, l, getCursor().getTypes(), cond, true);
1055 // Generate else branch first, otherwise iterator values will be updated by
1056 // `forward()`.
1057 b.setInsertionPointToStart(ifOp.elseBlock());
1058 YIELD(getCursor());
1059
1060 b.setInsertionPointToStart(ifOp.thenBlock());
1061 YIELD(forward(b, l));
1062
1063 b.setInsertionPointAfter(ifOp);
1064 seek(ifOp.getResults());
1065 return getCursor();
1066}
1067
1068Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
1069 auto whileOp = scf::WhileOp::create(
1070 b, l, pos.getType(), pos,
1071 /*beforeBuilder=*/
1072 [this, pos](OpBuilder &b, Location l, ValueRange ivs) {
1073 Value inBound = CMPI(ult, ivs.front(), posHi);
1074 auto ifInBound = scf::IfOp::create(b, l, b.getI1Type(), inBound, true);
1075 {
1076 OpBuilder::InsertionGuard guard(b);
1077 // If in bound, load the next coordinates and check duplication.
1078 b.setInsertionPointToStart(ifInBound.thenBlock());
1079 Value headCrd = stl.peekCrdAt(b, l, getBatchCrds(), pos);
1080 Value tailCrd = stl.peekCrdAt(b, l, getBatchCrds(), ivs.front());
1081 Value isDup = CMPI(eq, headCrd, tailCrd);
1082 YIELD(isDup);
1083 // Else, the position is out of bound, yield false.
1084 b.setInsertionPointToStart(ifInBound.elseBlock());
1085 YIELD(constantI1(b, l, false));
1086 }
1087 scf::ConditionOp::create(b, l, ifInBound.getResults()[0], ivs);
1088 },
1089 /*afterBuilder=*/
1090 [](OpBuilder &b, Location l, ValueRange ivs) {
1091 Value nxPos = ADDI(ivs[0], C_IDX(1));
1092 YIELD(nxPos);
1093 });
1094 // Return the segment high.
1095 return whileOp.getResult(0);
1096}
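// For example, on a non-unique level storing coordinates [3, 3, 3, 7] with
// posHi = 4, genSegmentHigh(b, l, 0) emits a loop that returns 3: the first
// position past the segment of duplicate coordinate 3.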
1097
1098Value FilterIterator::genCrdNotLegitPredicate(OpBuilder &b, Location l,
1099 Value wrapCrd) {
1100 Value crd = fromWrapCrd(b, l, wrapCrd);
1101 // Test whether the coordinate is on stride.
1102 Value notlegit = CMPI(ne, toWrapCrd(b, l, crd), wrapCrd);
1103 // Test wrapCrd < offset
1104 notlegit = ORI(CMPI(ult, wrapCrd, offset), notlegit);
1105 // Test crd >= length
1106 notlegit = ORI(CMPI(uge, crd, size), notlegit);
1107 return notlegit;
1108}
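// For example, with offset = 2, stride = 3, size = 2: wrapCrd = 4 is not legit
// because crd = (4 - 2) / 3 = 0 but toWrapCrd(0) = 2 != 4 (off stride), and
// wrapCrd = 8 is not legit because crd = 2 >= size.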
1109
1110Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
1111 auto r = genWhenInBound(
1112 b, l, *wrap, C_FALSE,
1113 [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1114 Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
1115 return {notLegit};
1116 });
1117 return llvm::getSingleElement(r);
1118}
1119
1120Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) {
1121 assert(!wrap->randomAccessible());
1122 auto r = genWhenInBound(
1123 b, l, *wrap, C_FALSE,
1124 [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1125 Value crd = fromWrapCrd(b, l, wrapCrd);
1126 // crd < size
1127 return {CMPI(ult, crd, size)};
1128 });
1129 return llvm::getSingleElement(r);
1130}
1131
1132ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
1133 assert(!randomAccessible());
1134 // Generates
1135 //
1136 // bool isFirst = true;
1137 // while !it.end() && (!legit(*it) || isFirst)
1138 // wrap ++;
1139 // isFirst = false;
1140 //
1141 // We do not hoist the first `wrap++` outside the loop but use a `isFirst`
1142 // flag here because `wrap++` might have a complex implementation (e.g., to
1143 // forward a subsection).
1144 Value isFirst = constantI1(b, l, true);
1145
1146 SmallVector<Value> whileArgs(getCursor().begin(), getCursor().end());
1147 whileArgs.push_back(isFirst);
1148 auto whileOp = scf::WhileOp::create(
1149 b, l, ValueRange(whileArgs).getTypes(), whileArgs,
1150 /*beforeBuilder=*/
1151 [this](OpBuilder &b, Location l, ValueRange ivs) {
1152 ValueRange isFirst = linkNewScope(ivs);
1153 scf::ValueVector cont =
1154 genWhenInBound(b, l, *wrap, C_FALSE,
1155 [this, isFirst](OpBuilder &b, Location l,
1156 Value wrapCrd) -> scf::ValueVector {
1157 // crd < size && !legit();
1158 Value notLegit =
1159 genCrdNotLegitPredicate(b, l, wrapCrd);
1160 Value crd = fromWrapCrd(b, l, wrapCrd);
1161 Value ret = ANDI(CMPI(ult, crd, size), notLegit);
1162 ret = ORI(ret, llvm::getSingleElement(isFirst));
1163 return {ret};
1164 });
1165 scf::ConditionOp::create(b, l, cont.front(), ivs);
1166 },
1167 /*afterBuilder=*/
1168 [this](OpBuilder &b, Location l, ValueRange ivs) {
1169 linkNewScope(ivs);
1170 wrap->forward(b, l);
1171 SmallVector<Value> yieldVals(getCursor().begin(), getCursor().end());
1172 yieldVals.push_back(constantI1(b, l, false));
1173 YIELD(yieldVals);
1174 });
1175
1176 b.setInsertionPointAfter(whileOp);
1177 linkNewScope(whileOp.getResults());
1178 return getCursor();
1179}
1180
1181SubSectIterHelper::SubSectIterHelper(const NonEmptySubSectIterator &subSect)
1182 : subSect(subSect), wrap(*subSect.delegate) {}
1183
1184SubSectIterHelper::SubSectIterHelper(const SubSectIterator &iter)
1185 : subSect(iter.subSect), wrap(*iter.wrap) {}
1186
1187void SubSectIterHelper::deserializeFromTupleId(OpBuilder &b, Location l,
1188 Value tupleId) {
1189 assert(!subSect.randomAccessible());
1190 wrap.deserialize(subSect.loadCursorVals(b, l, tupleId));
1191}
1192
1193void SubSectIterHelper::locate(OpBuilder &b, Location l, Value crd) {
1194 Value absCrd = ADDI(crd, subSect.getAbsOff());
1195 wrap.locate(b, l, absCrd);
1196}
1197
1198Value SubSectIterHelper::genNotEnd(OpBuilder &b, Location l) {
1199 assert(!wrap.randomAccessible());
1200 auto r = genWhenInBound(
1201 b, l, wrap, C_FALSE,
1202 [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1203 Value crd = SUBI(wrapCrd, subSect.getAbsOff());
1204 // crd < size
1205 return {CMPI(ult, crd, subSect.subSectSz)};
1206 });
1207 return llvm::getSingleElement(r);
1208}
1209
1210Value SubSectIterHelper::deref(OpBuilder &b, Location l) {
1211 Value wrapCrd = wrap.deref(b, l);
1212 Value crd = subSect.toSubSectCrd(b, l, wrapCrd);
1213 return crd;
1214}
1215
1216ValueRange SubSectIterHelper::forward(OpBuilder &b, Location l) {
1217 return wrap.forward(b, l);
1218}
1219
1220ValueRange NonEmptySubSectIterator::inflateSubSectTree(
1221 OpBuilder &b, Location l, ValueRange reduc, TraverseBuilder builder) const {
1222 // Set up the helper to help traverse a sparse subsection.
1223 SubSectIterHelper helper(*this);
1224 if (!randomAccessible()) {
1225 // The subsection tree has been expanded up to this level and cached;
1226 // traverse all the leaves and expand to the next level.
1227 SmallVector<Value> iterArgs;
1228 iterArgs.push_back(C_IDX(0));
1229 iterArgs.append(reduc.begin(), reduc.end());
1230 auto forEachLeaf = scf::ForOp::create(
1231 b, l, /*lb=*/C_IDX(0), /*ub=*/tupleCnt, /*step=*/C_IDX(1), iterArgs,
1232 [&helper, &builder](OpBuilder &b, Location l, Value tupleId,
1233 ValueRange iterArgs) {
1234 // Deserialize the iterator at the cached position (tupleId).
1235 helper.deserializeFromTupleId(b, l, tupleId);
1236
1237 Value cnt = iterArgs.front();
1238 // Record the number of leaf nodes included in the subsection.
1239 // The number indicates the starting tupleId for the next level that
1240 // corresponds to the current node.
1241 helper.subSect.storeNxLvlStart(b, l, tupleId, cnt);
1242
1243 SmallVector<Value> whileArgs(helper.wrap.getCursor());
1244 whileArgs.append(iterArgs.begin(), iterArgs.end());
1245
1246 auto whileOp = scf::WhileOp::create(
1247 b, l, ValueRange(whileArgs).getTypes(), whileArgs,
1248 /*beforeBuilder=*/
1249 [&helper](OpBuilder &b, Location l, ValueRange ivs) {
1250 helper.wrap.linkNewScope(ivs);
1251 scf::ConditionOp::create(b, l, helper.genNotEnd(b, l), ivs);
1252 },
1253 /*afterBuilder=*/
1254 [&helper, &builder](OpBuilder &b, Location l, ValueRange ivs) {
1255 ValueRange remIter = helper.wrap.linkNewScope(ivs);
1256 Value cnt = remIter.front();
1257 ValueRange userIter = remIter.drop_front();
1258 scf::ValueVector userNx = builder(b, l, &helper.wrap, userIter);
1259
1260 SmallVector<Value> nxIter = helper.forward(b, l);
1261 nxIter.push_back(ADDI(cnt, C_IDX(1)));
1262 nxIter.append(userNx.begin(), userNx.end());
1263 YIELD(nxIter);
1264 });
1265 ValueRange res = helper.wrap.linkNewScope(whileOp.getResults());
1266 YIELD(res);
1267 });
1268 return forEachLeaf.getResults().drop_front();
1269 }
1270
1271 assert(randomAccessible());
1272 // Helper lambda that traverses the current dense subsection range.
1273 auto visitDenseSubSect = [&, this](OpBuilder &b, Location l,
1274 const SparseIterator *parent,
1275 ValueRange reduc) {
1276 assert(!parent || parent->lvl + 1 == lvl);
1277 delegate->genInit(b, l, parent);
1278 auto forOp = scf::ForOp::create(
1279 b, l, /*lb=*/C_IDX(0), /*ub=*/subSectSz, /*step=*/C_IDX(1), reduc,
1280 [&](OpBuilder &b, Location l, Value crd, ValueRange iterArgs) {
1281 helper.locate(b, l, crd);
1282 scf::ValueVector nx = builder(b, l, &helper.wrap, iterArgs);
1283 YIELD(nx);
1284 });
1285 return forOp.getResults();
1286 };
1287
1288 if (isSubSectRoot()) {
1289 return visitDenseSubSect(b, l, parent, reduc);
1290 }
1291 // Else, this is not the root, recurse until root.
1292 auto *p = llvm::cast<NonEmptySubSectIterator>(parent);
1293 assert(p->lvl + 1 == lvl);
1294 return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
1295}
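// In short, for a non-random-accessible subsection the code above emits an
// scf.for over the cached tuples, each iteration wrapping an scf.while that
// walks the corresponding leaf range; for the dense case it recurses to the
// subsection root and emits nested scf.for loops of trip count subSectSz.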
1296
1297void TrivialIterator::genInitImpl(OpBuilder &b, Location l,
1298 const SparseIterator *parent) {
1299
1300 if (isBatchIterator() && batchCrds.size() <= stl.lvl)
1301 batchCrds.resize(stl.lvl + 1, nullptr);
1302
1303 Value c0 = C_IDX(0);
1304 ValueRange pPos = c0;
1305 Value inPadZone = nullptr;
1306 // If the parent iterator is a batch iterator, we also start from 0 (but
1307 // on a different batch).
1308 if (parent && !parent->isBatchIterator()) {
1309 pPos = parent->getCurPosition();
1310 if (llvm::isa<PadIterator>(parent) && parent->randomAccessible()) {
1311 // A padded dense iterator creates a "sparse" padded zone, which needs to
1312 // be handled specially.
1313 inPadZone = pPos.back();
1314 pPos = pPos.drop_back();
1315 }
1316 }
1317
1318 ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
1319 std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos, inPadZone);
1320 // Seek to the lowest position.
1321 seek(posLo);
1322}
1323
1324void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
1325 const SparseIterator *) {
1326 Value c0 = C_IDX(0);
1327 if (!isSubSectRoot()) {
1328 assert(parent->lvl + 1 == lvl);
1329 if (randomAccessible()) {
1330 // We cannot call wrap->genInit() here to initialize the wrapped
1331 // iterator, because the parent of the current iterator is still
1332 // unresolved.
1333 seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
1334 return;
1335 }
1336
1337 auto *p = cast<NonEmptySubSectIterator>(parent);
1338 SmallVector<Value, 3> reduc = {
1339 C_IDX(-1), // minCrd (max signless integer)
1340 c0, // tupleId
1341 };
1342
1343 // Expand the subsection tree from the parent level to the current level.
1344 ValueRange result = p->inflateSubSectTree(
1345 b, l, reduc,
1346 [this](OpBuilder &b, Location l, const SparseIterator *parent,
1347 ValueRange reduc) -> scf::ValueVector {
1348 assert(parent->lvl + 1 == lvl && reduc.size() == 2);
1349 Value minCrd = reduc.front();
1350 Value tupleId = reduc.back();
1351
1352 // Initialize the subsection range.
1353 SubSectIterHelper helper(*this);
1354 helper.wrap.genInit(b, l, parent);
1355
1356 // Update minCrd.
1357 minCrd = genWhenInBound(b, l, helper.wrap, minCrd,
1358 [minCrd](OpBuilder &b, Location l,
1359 Value crd) -> scf::ValueVector {
1360 Value min = MINUI(crd, minCrd);
1361 return {min};
1362 })
1363 .front();
1364
1365 // Cache the sparse range.
1366 storeCursorVals(b, l, tupleId, helper.wrap.serialize());
1367 tupleId = ADDI(tupleId, C_IDX(1));
1368 return {minCrd, tupleId};
1369 });
1370 assert(result.size() == 2);
1371 tupleCnt = result.back();
1372
1373 Value minCrd = result.front();
1374 Value absOff = offsetFromMinCrd(b, l, minCrd, subSectSz);
1375 Value notEnd = CMPI(ne, minCrd, C_IDX(-1));
1376 seek({minCrd, absOff, notEnd});
1377 return;
1378 }
1379
1380 // This is the root level of the subsection, which means that it is resolved
1381 // to one node.
1382 assert(isSubSectRoot());
1383
1384 // Initialize the position; it marks the *lower bound* of the subRange.
1385 // The upper bound is determined by the size of the subsection.
1386 delegate->genInit(b, l, parent);
1387 if (randomAccessible()) {
1388 seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
1389 return;
1390 }
1391
1392 // Only have one root node.
1393 tupleCnt = C_IDX(1);
1394 // Cache the sparse range.
1395 storeCursorVals(b, l, c0, delegate->serialize());
1396 SmallVector<Value> elseRet{c0, c0, /*notEnd=*/C_FALSE};
1397 auto meta = genWhenInBound(
1398 b, l, *delegate, elseRet,
1399 [this](OpBuilder &b, Location l, Value crd) -> scf::ValueVector {
1400 Value offset = offsetFromMinCrd(b, l, crd, subSectSz);
1401 return {crd, offset, C_TRUE};
1402 });
1403
1404 seek(meta);
1405}
1406
1407ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
1408 assert(!randomAccessible());
1409 Value c0 = C_IDX(0), c1 = C_IDX(1);
1410 // Forward to the next non-empty slice by generating
1411 //
1412 // if (minCrd > offset) {
1413 // offset += 1
1414 // } else {
1415 // minCrd = nextMinInSlice();
1416 // offset = minCrd - size + 1;
1417 // }
1418 //
1419 // if (offset + size > parents.size)
1420 // isNonEmpty = false;
1421 Value fastPathP = CMPI(ugt, getMinCrd(), getAbsOff());
1422 auto ifOp = scf::IfOp::create(b, l, getCursor().getTypes(), fastPathP, true);
1423 {
1424 OpBuilder::InsertionGuard guard(b);
1425 // Take the fast path
1426 // if (minCrd > offset)
1427 // offset += 1
1428 b.setInsertionPointToStart(&ifOp.getThenRegion().front());
1429 Value nxOffset = ADDI(getAbsOff(), c1);
1430 YIELD((ValueRange{getMinCrd(), nxOffset, getNotEnd()}));
1431
1432 // else /*minCrd == offset*/ {
1433 // for (i = 0; i < tupleCnt; i++) {
1434 // wrap->deserialize(pos[i]);
1435 // minCrd=min(minCrd, *wrap);
1436 // }
1437 // offset = minCrd - size + 1;
1438 // }
1439 b.setInsertionPointToStart(&ifOp.getElseRegion().front());
1440 SmallVector<Value, 2> loopArgs{C_IDX(-1), // nextMinCrd
1441 C_FALSE}; // isNotEnd
1442 auto loopNest = scf::buildLoopNest(
1443 b, l, c0, tupleCnt, c1, loopArgs,
1444 [this](OpBuilder &b, Location l, ValueRange ivs,
1445 ValueRange iterArgs) -> scf::ValueVector {
1446 Value tupleId = ivs.front();
1447 SubSectIterHelper helper(*this);
1448 helper.deserializeFromTupleId(b, l, tupleId);
1449
1450 return genWhenInBound(
1451 b, l, *delegate, /*elseRet=*/iterArgs,
1452 [this, iterArgs, tupleId](OpBuilder &b, Location l,
1453 Value crd) -> scf::ValueVector {
1454 // if coord == minCrd
1455 // wrap->forward();
1456 Value isMin = CMPI(eq, crd, getMinCrd());
1457 delegate->forwardIf(b, l, isMin);
1458 // Update the forwarded iterator values if needed.
1459 auto ifIsMin = scf::IfOp::create(b, l, isMin, false);
1460 b.setInsertionPointToStart(&ifIsMin.getThenRegion().front());
1461 storeCursorVals(b, l, tupleId, delegate->serialize());
1462 b.setInsertionPointAfter(ifIsMin);
1463 // if (!wrap.end())
1464 // yield(min(nxMinCrd, *wrap), true)
1465 Value nxMin = iterArgs[0];
1466 return genWhenInBound(b, l, *delegate, /*elseRet=*/iterArgs,
1467 [nxMin](OpBuilder &b, Location l,
1468 Value crd) -> scf::ValueVector {
1469 Value nx = arith::MinUIOp::create(
1470 b, l, crd, nxMin);
1471 return {nx, C_TRUE};
1472 });
1473 });
1474 });
1475
1476 scf::ForOp forOp = loopNest.loops.front();
1477 b.setInsertionPointAfter(forOp);
1478
1479 Value nxMinCrd = forOp.getResult(0);
1480 Value nxNotEnd = forOp.getResult(1);
1481 Value nxAbsOff = offsetFromMinCrd(b, l, nxMinCrd, subSectSz);
1482 YIELD((ValueRange{nxMinCrd, nxAbsOff, nxNotEnd}));
1483 }
1484
1485 Value nxMinCrd = ifOp.getResult(0);
1486 Value nxAbsOff = ifOp.getResult(1);
1487 Value nxNotEnd = ifOp.getResult(2);
1488
1489 // We should at least forward the offset by one.
1490 Value minAbsOff = ADDI(getAbsOff(), c1);
1491 nxAbsOff = arith::MaxUIOp::create(b, l, minAbsOff, nxAbsOff);
1492
1493 seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
1494 // The coordinate should not exceed the space's upper bound.
1495 Value crd = deref(b, l);
1496 nxNotEnd = ANDI(nxNotEnd, CMPI(ult, crd, upperBound(b, l)));
1497
1498 seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
1499 return getCursor();
1500}
1501
1502//===----------------------------------------------------------------------===//
1503// SparseIterationSpace Implementation
1504//===----------------------------------------------------------------------===//
1505
1506mlir::sparse_tensor::SparseIterationSpace::SparseIterationSpace(
1507 Location l, OpBuilder &b, Value t, unsigned tid,
1508 std::pair<Level, Level> lvlRange, ValueRange parentPos)
1509 : lvls() {
1510 auto [lvlLo, lvlHi] = lvlRange;
1511
1512 Value c0 = C_IDX(0);
1513 if (parentPos.empty())
1514 parentPos = c0;
1515
1516 for (Level lvl = lvlLo; lvl < lvlHi; lvl++)
1517 lvls.emplace_back(makeSparseTensorLevel(b, l, t, tid, lvl));
1518
1519 bound = lvls.front()->peekRangeAt(b, l, /*batchPrefix=*/{}, parentPos);
1520 for (auto &lvl : getLvlRef().drop_front())
1521 bound = lvl->collapseRangeBetween(b, l, /*batchPrefix=*/{}, bound);
1522}
1523
1524SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
1525 IterSpaceType dstTp, ValueRange values, unsigned int tid) {
1526 // Reconstruct every sparse tensor level.
1527 SparseIterationSpace space;
1528 for (auto [i, lt] : llvm::enumerate(dstTp.getLvlTypes())) {
1529 unsigned bufferCnt = 0;
1530 if (lt.isWithPosLT())
1531 bufferCnt++;
1532 if (lt.isWithCrdLT())
1533 bufferCnt++;
1534 // Sparse tensor buffers.
1535 ValueRange buffers = values.take_front(bufferCnt);
1536 values = values.drop_front(bufferCnt);
1537
1538 // Level size.
1539 Value sz = values.front();
1540 values = values.drop_front();
1541 space.lvls.push_back(
1542 makeSparseTensorLevel(lt, sz, buffers, tid, i + dstTp.getLoLvl()));
1543 }
1544 // Two bounds.
1545 space.bound = std::make_pair(values[0], values[1]);
1546 values = values.drop_front(2);
1547
1548 // Must have consumed all values.
1549 assert(values.empty());
1550 return space;
1551}
1552
1553std::unique_ptr<SparseIterator>
1554SparseIterationSpace::extractIterator(OpBuilder &b, Location l) const {
1555 return makeSimpleIterator(b, l, *this);
1556}
1557
1558//===----------------------------------------------------------------------===//
1559// SparseIterator factory functions.
1560//===----------------------------------------------------------------------===//
1561
1562/// Helper function to create a TensorLevel object from given `tensor`.
1563std::unique_ptr<SparseTensorLevel>
1564sparse_tensor::makeSparseTensorLevel(LevelType lt, Value sz, ValueRange b,
1565 unsigned t, Level l) {
1566 assert(lt.getNumBuffer() == b.size());
1567 switch (lt.getLvlFmt()) {
1568 case LevelFormat::Dense:
1569 return std::make_unique<DenseLevel>(t, l, sz);
1570 case LevelFormat::Batch:
1571 return std::make_unique<BatchLevel>(t, l, sz);
1572 case LevelFormat::Compressed:
1573 return std::make_unique<CompressedLevel>(t, l, lt, sz, b[0], b[1]);
1574 case LevelFormat::LooseCompressed:
1575 return std::make_unique<LooseCompressedLevel>(t, l, lt, sz, b[0], b[1]);
1576 case LevelFormat::Singleton:
1577 return std::make_unique<SingletonLevel>(t, l, lt, sz, b[0]);
1578 case LevelFormat::NOutOfM:
1579 return std::make_unique<NOutOfMLevel>(t, l, lt, sz, b[0]);
1580 case LevelFormat::Undef:
1581 llvm_unreachable("undefined level format");
1582 }
1583 llvm_unreachable("unrecognizable level format");
1584}
1585
1586std::unique_ptr<SparseTensorLevel>
1587sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
1588 unsigned tid, Level lvl) {
1589 auto stt = getSparseTensorType(t);
1590
1591 LevelType lt = stt.getLvlType(lvl);
1592 Value sz = stt.hasEncoding()
1593 ? LvlOp::create(b, l, t, lvl).getResult()
1594 : tensor::DimOp::create(b, l, t, lvl).getResult();
1595
1596 SmallVector<Value, 2> buffers;
1597 if (lt.isWithPosLT()) {
1598 Value pos = ToPositionsOp::create(b, l, t, lvl);
1599 buffers.push_back(pos);
1600 }
1601 if (lt.isWithCrdLT()) {
1602 Value pos = ToCoordinatesOp::create(b, l, t, lvl);
1603 buffers.push_back(pos);
1604 }
1605 return makeSparseTensorLevel(lt, sz, buffers, tid, lvl);
1606}
1607
1608std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
1609sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
1610 SparseEmitStrategy strategy) {
1611 auto stl = std::make_unique<BatchLevel>(tid, lvl, sz);
1612 auto it = std::make_unique<TrivialIterator>(*stl);
1613 it->setSparseEmitStrategy(strategy);
1614 return std::make_pair(std::move(stl), std::move(it));
1615}
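// A synthetic level pairs a BatchLevel of the given size with a TrivialIterator;
// it simply iterates the coordinate range [0, sz) without any tensor buffers
// attached.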
1616
1617std::unique_ptr<SparseIterator>
1618sparse_tensor::makeSimpleIterator(OpBuilder &b, Location l,
1619 const SparseIterationSpace &iterSpace) {
1620 // assert(iterSpace.getSpaceDim() == 1);
1621 std::unique_ptr<SparseIterator> ret;
1622 if (!iterSpace.isUnique()) {
1623 // We always deduplicate the non-unique level, but we should optimize it away
1624 // if possible.
1625 ret = std::make_unique<DedupIterator>(b, l, iterSpace.getLastLvl(),
1626 iterSpace.getBoundLo(),
1627 iterSpace.getBoundHi());
1628 } else {
1629 ret = std::make_unique<TrivialIterator>(b, l, iterSpace.getLastLvl(),
1630 iterSpace.getBoundLo(),
1631 iterSpace.getBoundHi());
1632 }
1633 ret->setSparseEmitStrategy(SparseEmitStrategy::kFunctional);
1634 return ret;
1635}
1636
1637std::unique_ptr<SparseIterator>
1638sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
1639 SparseEmitStrategy strategy) {
1640 std::unique_ptr<SparseIterator> ret;
1641 if (!isUniqueLT(stl.getLT())) {
1642 // We always deduplicate the non-unique level, but we should optimize it away
1643 // if possible.
1644 ret = std::make_unique<DedupIterator>(stl);
1645 } else {
1646 ret = std::make_unique<TrivialIterator>(stl);
1647 }
1648 ret->setSparseEmitStrategy(strategy);
1649 return ret;
1650}
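// Rough sketch of the dispatch above: a unique level gets a TrivialIterator
// that walks its position range directly, while a non-unique level (e.g. a
// singleton level storing duplicate coordinates) gets a DedupIterator that
// advances over whole segments of equal coordinates.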
1651
1652std::unique_ptr<SparseIterator>
1653sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
1654 Value offset, Value stride, Value size,
1655 SparseEmitStrategy strategy) {
1656
1657 auto ret =
1658 std::make_unique<FilterIterator>(std::move(sit), offset, stride, size);
1659 ret->setSparseEmitStrategy(strategy);
1660 return ret;
1661}
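// Sketch of the intended slice semantics: the FilterIterator keeps only
// coordinates of the form offset + i * stride with i in [0, size) and rebases
// them to the slice-local index i; see FilterIterator above for the exact
// legality checks.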
1662
1663std::unique_ptr<SparseIterator>
1664sparse_tensor::makePaddedIterator(std::unique_ptr<SparseIterator> &&sit,
1665 Value padLow, Value padHigh,
1666 SparseEmitStrategy strategy) {
1667 auto ret = std::make_unique<PadIterator>(std::move(sit), padLow, padHigh);
1668 ret->setSparseEmitStrategy(strategy);
1669 return ret;
1670}
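// Rough sketch: the PadIterator enlarges the wrapped space by padLow entries
// in front and padHigh entries at the back, mapping the original coordinates
// into the padded coordinate system; see PadIterator above for the details.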
1671
1672static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) {
1673 auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
1674 if (filter)
1675 return &filter->getWrappedIterator();
1676 return it;
1677}
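// Peels off at most one FilterIterator layer so that the subsection factories
// below operate on the underlying iterator rather than its slice/stride
// wrapper; a non-filter iterator is returned unchanged.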
1678
1679std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
1680 OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
1681 std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
1682 SparseEmitStrategy strategy) {
1683 // Try to unwrap the NonEmptySubSectIterator from a filter parent.
1684 // Try unwrap the NonEmptySubSectIterator from a filter parent.
1685 parent = tryUnwrapFilter(parent);
1686 std::unique_ptr<SparseIterator> it =
1687 std::make_unique<NonEmptySubSectIterator>(b, l, parent,
1688 std::move(delegate), size);
1689
1690 if (stride != 1) {
1691 // TODO: We can safely skip bound checking on sparse levels, but for dense
1692 // iteration space, we need the bound to infer the dense loop range.
1693 it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1694 C_IDX(stride), /*size=*/loopBound);
1695 }
1696 it->setSparseEmitStrategy(strategy);
1697 return it;
1698}
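// With stride != 1 the subsection iterator is additionally wrapped in a
// FilterIterator at offset 0, so roughly only every stride-th position is
// visited; loopBound is used as the filter size so a dense parent still gets
// a finite loop range (the TODO above notes the check is redundant for sparse
// levels).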
1699
1700std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator(
1701 OpBuilder &b, Location l, const SparseIterator &subSectIter,
1702 const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
1703 Value loopBound, unsigned stride, SparseEmitStrategy strategy) {
1704
1705 // This must be a subsection iterator or a filtered subsection iterator.
1706 auto &subSect =
1707 llvm::cast<NonEmptySubSectIterator>(*tryUnwrapFilter(&subSectIter));
1708
1709 std::unique_ptr<SparseIterator> it = std::make_unique<SubSectIterator>(
1710 subSect, *tryUnwrapFilter(&parent), std::move(wrap));
1711
1712 if (stride != 1) {
1713 it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1714 C_IDX(stride), /*size=*/loopBound);
1715 }
1716 it->setSparseEmitStrategy(strategy);
1717 return it;
1718}
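// Mirrors makeNonEmptySubSectIterator: any filter wrappers around the
// subsection iterator and the parent are peeled off before building the
// SubSectIterator, and a stride != 1 again adds a FilterIterator on top with
// loopBound as its size.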
1719
1720#undef CMPI
1721#undef C_IDX
1722#undef YIELD
1723#undef ADDI
1724#undef ANDI
1725#undef SUBI
1726#undef MULI
1727#undef SELECT