SparseTensorIterator.cpp
1 //===- SparseTensorIterator.cpp -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "SparseTensorIterator.h"
10 #include "CodegenUtils.h"
11 
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 
15 
16 using namespace mlir;
17 using namespace mlir::sparse_tensor;
18 using ValuePair = std::pair<Value, Value>;
19 using ValueTuple = std::tuple<Value, Value, Value>;
20 
21 //===----------------------------------------------------------------------===//
22 // File local helper functions/macros.
23 //===----------------------------------------------------------------------===//
24 #define CMPI(p, lhs, rhs) \
25  (b.create<arith::CmpIOp>(l, arith::CmpIPredicate::p, (lhs), (rhs)) \
26  .getResult())
27 
28 #define C_FALSE (constantI1(b, l, false))
29 #define C_TRUE (constantI1(b, l, true))
30 #define C_IDX(v) (constantIndex(b, l, (v)))
31 #define YIELD(vs) (b.create<scf::YieldOp>(l, (vs)))
32 #define ADDI(lhs, rhs) (b.create<arith::AddIOp>(l, (lhs), (rhs)).getResult())
33 #define ORI(lhs, rhs) (b.create<arith::OrIOp>(l, (lhs), (rhs)).getResult())
34 #define ANDI(lhs, rhs) (b.create<arith::AndIOp>(l, (lhs), (rhs)).getResult())
35 #define SUBI(lhs, rhs) (b.create<arith::SubIOp>(l, (lhs), (rhs)).getResult())
36 #define MULI(lhs, rhs) (b.create<arith::MulIOp>(l, (lhs), (rhs)).getResult())
37 #define MINUI(lhs, rhs) (b.create<arith::MinUIOp>(l, (lhs), (rhs)).getResult())
38 #define REMUI(lhs, rhs) (b.create<arith::RemUIOp>(l, (lhs), (rhs)).getResult())
39 #define DIVUI(lhs, rhs) (b.create<arith::DivUIOp>(l, (lhs), (rhs)).getResult())
40 #define SELECT(c, lhs, rhs) \
41  (b.create<arith::SelectOp>(l, (c), (lhs), (rhs)).getResult())
42 
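// NOTE (illustrative, not part of the original file): the macros above assume
// an `OpBuilder b` and a `Location l` in scope. For example, CMPI(ult, pos, posHi)
// expands to roughly
//   b.create<arith::CmpIOp>(l, arith::CmpIPredicate::ult, pos, posHi).getResult()
// yielding an i1 Value, while C_IDX(1) materializes the index constant 1.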
43 //===----------------------------------------------------------------------===//
44 // SparseTensorLevel derived classes.
45 //===----------------------------------------------------------------------===//
46 
47 namespace {
48 
49 template <bool hasPosBuffer>
50 class SparseLevel : public SparseTensorLevel {
51  // It is either an array of size 2 or size 1 depending on whether the sparse
52  // level requires a position array.
53  using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
54  std::array<Value, 1>>;
55 
56 public:
57  SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
58  BufferT buffers)
59  : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}
60 
61  ValueRange getLvlBuffers() const override { return buffers; }
62 
63  Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
64  Value iv) const override {
65  SmallVector<Value> memCrd(batchPrefix);
66  memCrd.push_back(iv);
67  return genIndexLoad(b, l, getCrdBuf(), memCrd);
68  }
69 
70 protected:
71  template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
72  Value getPosBuf() const {
73  return buffers[0];
74  }
75 
76  Value getCrdBuf() const {
77  if constexpr (hasPosBuffer)
78  return buffers[1];
79  else
80  return buffers[0];
81  }
82 
83  const BufferT buffers;
84 };
85 
86 class DenseLevel : public SparseTensorLevel {
87 public:
88  DenseLevel(unsigned tid, Level lvl, Value lvlSize)
89  : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize) {}
90 
91  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
92  llvm_unreachable("locate random-accessible level instead");
93  }
94 
95  ValueRange getLvlBuffers() const override { return {}; }
96 
97  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
98  ValueRange parentPos) const override {
99  assert(parentPos.size() == 1 && "Dense level cannot be non-unique.");
100  Value p = parentPos.front();
101  Value posLo = MULI(p, lvlSize);
102  return {posLo, lvlSize};
103  }
104 };
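// Worked example (illustrative, not part of the original file): for a dense
// level of size 4 and parent position p = 2, peekRangeAt computes
// posLo = MULI(p, lvlSize) = 8 and returns it together with the level size,
// i.e., the children of this parent occupy the linearized positions [8, 12).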
105 
106 class BatchLevel : public SparseTensorLevel {
107 public:
108  BatchLevel(unsigned tid, Level lvl, Value lvlSize)
109  : SparseTensorLevel(tid, lvl, LevelFormat::Batch, lvlSize) {}
110 
111  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
112  llvm_unreachable("locate random-accessible level instead");
113  }
114 
115  ValueRange getLvlBuffers() const override { return {}; }
116 
117  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
118  ValueRange parentPos) const override {
119  assert(parentPos.size() == 1 && "Batch level cannot be non-unique.");
120  // No need to linearize the position for non-annotated tensors.
121  return {C_IDX(0), lvlSize};
122  }
123 };
124 
125 class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
126 public:
127  CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
128  Value posBuffer, Value crdBuffer)
129  : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
130 
131  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
132  ValueRange parentPos) const override {
133 
134  assert(parentPos.size() == 1 &&
135  "compressed level must be the first non-unique level.");
136  Value p = parentPos.front();
137 
138  SmallVector<Value> memCrd(batchPrefix);
139  memCrd.push_back(p);
140  Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
141  memCrd.back() = ADDI(p, C_IDX(1));
142  Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
143  return {pLo, pHi};
144  }
145 };
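// Worked example (illustrative, not part of the original file): with a
// CSR-like positions buffer pos = [0, 2, 5] and coordinates crd = [0, 3, 1, 2, 4],
// parent position p = 1 yields pLo = pos[1] = 2 and pHi = pos[2] = 5, so the
// iterator visits crd[2..5) = {1, 2, 4}, i.e., the stored columns of row 1.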
146 
147 class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
148 public:
149  LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
150  Value posBuffer, Value crdBuffer)
151  : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
152 
153  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
154  ValueRange parentPos) const override {
155  assert(parentPos.size() == 1 &&
156  "loose-compressed level must be the first non-unique level.");
157  SmallVector<Value> memCrd(batchPrefix);
158  Value p = parentPos.front();
159  p = MULI(p, C_IDX(2));
160  memCrd.push_back(p);
161  Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
162  memCrd.back() = ADDI(p, C_IDX(1));
163  Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
164  return {pLo, pHi};
165  }
166 };
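// Worked example (illustrative): a loose-compressed level keeps two position
// entries per parent, so parent p reads pLo = pos[2*p] and pHi = pos[2*p + 1].
// E.g., with pos = [0, 2, 4, 5], p = 1 gives pLo = pos[2] = 4 and pHi = pos[3] = 5;
// any slack left between consecutive segments is simply skipped.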
167 
168 class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
169 public:
170  SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
171  Value crdBuffer)
172  : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
173 
174  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
175  ValueRange parentPos) const override {
176  assert(parentPos.size() == 1 || parentPos.size() == 2);
177  Value p = parentPos.front();
178  Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
179 
180  if (segHi == nullptr)
181  return {p, ADDI(p, C_IDX(1))};
182  // Use the segHi as the loop upper bound.
183  return {p, segHi};
184  }
185 };
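// Illustrative note (not part of the original file): a singleton level (e.g.,
// the trailing levels of a COO encoding) stores exactly one coordinate per
// parent position, hence the default range [p, p + 1). When the parent also
// provides a segment high (segHi), the range [p, segHi) spans all entries that
// share the parent's coordinate.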
186 
187 class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
188 public:
189  NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
190  Value crdBuffer)
191  : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
192 
193  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
194  ValueRange parentPos) const override {
195  assert(parentPos.size() == 1 && isUnique() &&
196  "n:m level cannot be non-unique.");
197  // Each n:m blk has exactly n specified elements.
198  auto n = getN(lt);
199  Value posLo = MULI(parentPos.front(), C_IDX(n));
200  return {posLo, ADDI(posLo, C_IDX(n))};
201  }
202 };
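// Worked example (illustrative): for 2:4 structured sparsity, getN(lt) == 2,
// so parent block p = 3 maps to the position range [3 * 2, 3 * 2 + 2) = [6, 8),
// i.e., exactly the two stored elements of that block.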
203 
204 } // namespace
205 
206 //===----------------------------------------------------------------------===//
207 // File local helpers
208 //===----------------------------------------------------------------------===//
209 
210 static scf::ValueVector genWhenInBound(
211     OpBuilder &b, Location l, SparseIterator &it, ValueRange elseRet,
212     llvm::function_ref<scf::ValueVector(OpBuilder &, Location, Value)>
213         builder) {
214  TypeRange ifRetTypes = elseRet.getTypes();
215  auto ifOp = b.create<scf::IfOp>(l, ifRetTypes, it.genNotEnd(b, l), true);
216 
217  b.setInsertionPointToStart(ifOp.thenBlock());
218  Value crd = it.deref(b, l);
219  scf::ValueVector ret = builder(b, l, crd);
220  YIELD(ret);
221 
222  b.setInsertionPointToStart(ifOp.elseBlock());
223  YIELD(elseRet);
224 
225  b.setInsertionPointAfter(ifOp);
226  return ifOp.getResults();
227 }
228 
229 /// Generates code to compute the *absolute* offset of the slice based on the
230 /// provided minimum coordinates in the slice.
231 /// E.g., when reducing d0 + d1 + d2, we need two slices to fully reduce the
232 /// expression, i.e., s1 = slice(T, d0), s2 = slice(s1, d1). The *absolute*
233 /// offset is the offset computed relative to the initial tensor T.
234 ///
235 /// When isNonEmpty == true, the computed offset is meaningless and should not
236 /// be used at runtime; in that case the method currently generates code that
237 /// returns 0.
238 ///
239 /// offset = minCrd >= size ? minCrd - size + 1 : 0;
240 static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
241                               Value size) {
242  Value geSize = CMPI(uge, minCrd, size);
243  // Compute minCrd - size + 1.
244  Value mms = SUBI(ADDI(minCrd, C_IDX(1)), size);
245  // This is the absolute offset relative to the actual tensor.
246  return SELECT(geSize, mms, C_IDX(0));
247 }
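// Worked example (illustrative): with a subsection size of 4, minCrd = 9 gives
// offset = 9 - 4 + 1 = 6, the smallest offset whose size-4 slice [6, 10) still
// contains coordinate 9, whereas minCrd = 2 < 4 yields offset = 0.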
248 
249 //===----------------------------------------------------------------------===//
250 // SparseIterator derived classes.
251 //===----------------------------------------------------------------------===//
252 
253 namespace {
254 
255 // The iterator that traverses a concrete sparse tensor level. High-level
256 // abstract iterators wrap it to achieve more complex goals (such as collapsing
257 // several levels). It also provides the common storage that holds the
258 // mlir::Values for itself as well as for its wrappers.
259 class ConcreteIterator : public SparseIterator {
260 protected:
261  ConcreteIterator(const SparseTensorLevel &stl, IterKind kind,
262  unsigned cursorValCnt)
263  : SparseIterator(kind, stl.tid, stl.lvl, cursorValCnt, cursorValsStorage),
264  stl(stl), cursorValsStorage(cursorValCnt, nullptr) {
265  assert(getCursor().size() == cursorValCnt);
266  };
267 
268 public:
269  // For LLVM-style RTTI.
270  static bool classof(const SparseIterator *from) {
271  return from->kind == IterKind::kTrivial;
272  }
273 
274  bool isBatchIterator() const override {
275  return stl.getLT().isa<LevelFormat::Batch>();
276  }
277  bool randomAccessible() const override {
278  return stl.getLT().hasDenseSemantic();
279  };
280  bool iteratableByFor() const override { return kind != IterKind::kDedup; };
281  Value upperBound(OpBuilder &b, Location l) const override {
282  return stl.getSize();
283  };
284 
285 protected:
286  const SparseTensorLevel &stl;
287  // Owner of the storage; all wrappers built on top of a concrete iterator
288  // share the same storage so that the iterator values are always
289  // synchronized.
290  SmallVector<Value> cursorValsStorage;
291 };
292 
293 class TrivialIterator : public ConcreteIterator {
294 public:
295  TrivialIterator(const SparseTensorLevel &stl)
296  : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
297 
298  std::string getDebugInterfacePrefix() const override {
299  return std::string("trivial<") + stl.toString() + ">";
300  }
301  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
302  return {b.getIndexType()};
303  }
304 
305  SmallVector<Value> serialize() const override {
306  SmallVector<Value> ret;
307  ret.push_back(getItPos());
308  if (randomAccessible()) {
309  // The loop upper bound is implicit (defined by `upperBound()`) for a
310  // random-access iterator, but we need to remember posLo for linearization.
311  ret.push_back(posLo);
312  } else {
313  ret.push_back(posHi);
314  }
315  return ret;
316  };
317 
318  void deserialize(ValueRange vs) override {
319  assert(vs.size() == 2);
320  seek(vs.front());
321  if (randomAccessible())
322  posLo = vs.back();
323  else
324  posHi = vs.back();
325  };
326 
327  void genInitImpl(OpBuilder &b, Location l,
328  const SparseIterator *parent) override {
329 
330  if (isBatchIterator() && batchCrds.size() <= stl.lvl)
331  batchCrds.resize(stl.lvl + 1, nullptr);
332 
333  Value c0 = C_IDX(0);
334  ValueRange pPos = c0;
335  // If the parent iterator is a batch iterator, we also start from 0 (but
336  // on a different batch).
337  if (parent && !parent->isBatchIterator())
338  pPos = parent->getCurPosition();
339 
340  ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
341  std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
342  // Seek to the lowest position.
343  seek(posLo);
344  }
345 
346  ValuePair genForCond(OpBuilder &b, Location l) override {
347  if (randomAccessible())
348  return {deref(b, l), upperBound(b, l)};
349  return std::make_pair(getItPos(), posHi);
350  }
351 
352  Value genNotEndImpl(OpBuilder &b, Location l) override {
353  // We use the first level's bound as the bound for the collapsed set of levels.
354  return CMPI(ult, getItPos(), posHi);
355  }
356 
357  Value derefImpl(OpBuilder &b, Location l) override {
358  if (randomAccessible()) {
359  updateCrd(SUBI(getItPos(), posLo));
360  } else {
361  updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getItPos()));
362  }
363  return getCrd();
364  };
365 
366  ValueRange forwardImpl(OpBuilder &b, Location l) override {
367  seek(ADDI(getItPos(), C_IDX(1)));
368  return getCursor();
369  }
370 
371  ValueRange forwardIf(OpBuilder &b, Location l, Value cond) override {
372  Value curPos = getCursor().front();
373  Value nxPos = forward(b, l).front();
374  seek(SELECT(cond, nxPos, curPos));
375  return getCursor();
376  }
377 
378  void locateImpl(OpBuilder &b, Location l, Value crd) override {
379  assert(randomAccessible());
380  // Seek to the linearized position.
381  seek(ADDI(crd, posLo));
382  updateCrd(crd);
383  if (isBatchIterator()) {
384  // If this is a batch iterator, also update the batch coordinate.
385  assert(batchCrds.size() > lvl);
386  batchCrds[lvl] = crd;
387  }
388  }
389 
390  Value getItPos() const { return getCursor().front(); }
391  Value posLo, posHi;
392 };
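// Illustrative sketch (assumption: the surrounding loop emitter drives this
// iterator): for a compressed level, a TrivialIterator is typically lowered to
//   scf.for %pos = %posLo to %posHi step %c1 {
//     %crd = memref.load %coordinates[%pos]   // derefImpl()
//     ...
//   }
// whereas for a random-accessible (dense/batch) level the coordinate itself
// drives the loop and the position is recovered as %pos = %crd + %posLo
// (see locateImpl()).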
393 
394 class DedupIterator : public ConcreteIterator {
395 private:
396  Value genSegmentHigh(OpBuilder &b, Location l, Value pos);
397 
398 public:
399  DedupIterator(const SparseTensorLevel &stl)
400  : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) {
401  assert(!stl.isUnique());
402  }
403  // For LLVM-style RTTI.
404  static bool classof(const SparseIterator *from) {
405  return from->kind == IterKind::kDedup;
406  }
407 
408  std::string getDebugInterfacePrefix() const override {
409  return std::string("dedup<") + stl.toString() + ">";
410  }
411  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
412  return {b.getIndexType(), b.getIndexType()};
413  }
414 
415  void genInitImpl(OpBuilder &b, Location l,
416  const SparseIterator *parent) override {
417  Value c0 = C_IDX(0);
418  ValueRange pPos = c0;
419 
420  // If the parent iterator is a batch iterator, we also start from 0 (but
421  // on a different batch).
422  if (parent && !parent->isBatchIterator())
423  pPos = parent->getCurPosition();
424 
425  Value posLo;
426  ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
427  std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
428 
429  seek({posLo, genSegmentHigh(b, l, posLo)});
430  }
431 
432  SmallVector<Value> serialize() const override {
433  SmallVector<Value> ret;
434  ret.append(getCursor().begin(), getCursor().end());
435  ret.push_back(posHi);
436  return ret;
437  };
438  void deserialize(ValueRange vs) override {
439  assert(vs.size() == 3);
440  seek(vs.take_front(getCursor().size()));
441  posHi = vs.back();
442  };
443 
444  Value genNotEndImpl(OpBuilder &b, Location l) override {
445  return CMPI(ult, getPos(), posHi);
446  }
447 
448  Value derefImpl(OpBuilder &b, Location l) override {
449  updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getPos()));
450  return getCrd();
451  };
452 
453  ValueRange forwardImpl(OpBuilder &b, Location l) override {
454  Value nxPos = getSegHi(); // forward the position to the next segment.
455  seek({nxPos, genSegmentHigh(b, l, nxPos)});
456  return getCursor();
457  }
458 
459  Value getPos() const { return getCursor()[0]; }
460  Value getSegHi() const { return getCursor()[1]; }
461 
462  Value posHi;
463 };
464 
465 //
466 // A filter iterator wrapped around another iterator. The filter iterator
467 // updates the wrapped iterator *in-place*.
468 //
469 class FilterIterator : public SparseIterator {
470  // Coordinate translation between the crd loaded from the wrapped iterator
471  // and the filter iterator.
472  Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) const {
473  // crd = (wrapCrd - offset) / stride
474  return DIVUI(SUBI(wrapCrd, offset), stride);
475  }
476  Value toWrapCrd(OpBuilder &b, Location l, Value crd) const {
477  // wrapCrd = crd * stride + offset
478  return ADDI(MULI(crd, stride), offset);
479  }
480 
481  Value genCrdNotLegitPredicate(OpBuilder &b, Location l, Value wrapCrd);
482 
483  Value genShouldFilter(OpBuilder &b, Location l);
484 
485 public:
486  // TODO: avoid unnecessary checks when offset == 0 and/or when stride == 1
487  // and/or when crd is always < size.
488  FilterIterator(std::unique_ptr<SparseIterator> &&wrap, Value offset,
489  Value stride, Value size)
490  : SparseIterator(IterKind::kFilter, *wrap), offset(offset),
491  stride(stride), size(size), wrap(std::move(wrap)) {}
492 
493  // For LLVM-style RTTI.
494  static bool classof(const SparseIterator *from) {
495  return from->kind == IterKind::kFilter;
496  }
497 
498  std::string getDebugInterfacePrefix() const override {
499  return std::string("filter<") + wrap->getDebugInterfacePrefix() + ">";
500  }
501  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
502  return wrap->getCursorValTypes(b);
503  }
504 
505  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
506  bool randomAccessible() const override { return wrap->randomAccessible(); };
507  bool iteratableByFor() const override { return randomAccessible(); };
508  Value upperBound(OpBuilder &b, Location l) const override { return size; };
509 
510  SmallVector<Value> serialize() const override { return wrap->serialize(); };
511  void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
512  ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
513 
514  void genInitImpl(OpBuilder &b, Location l,
515  const SparseIterator *parent) override {
516  wrap->genInit(b, l, parent);
517  if (!randomAccessible()) {
518  // TODO: we can skip this when stride == 1 and offset == 0; we can also
519  // use binary search here.
520  forwardIf(b, l, genShouldFilter(b, l));
521  } else {
522  // Else, locate to the slice.offset, which is the first coordinate
523  // included by the slice.
524  wrap->locate(b, l, offset);
525  }
526  }
527 
528  Value genNotEndImpl(OpBuilder &b, Location l) override;
529 
530  Value derefImpl(OpBuilder &b, Location l) override {
531  updateCrd(fromWrapCrd(b, l, wrap->deref(b, l)));
532  return getCrd();
533  }
534 
535  void locateImpl(OpBuilder &b, Location l, Value crd) override {
536  assert(randomAccessible());
537  wrap->locate(b, l, toWrapCrd(b, l, crd));
538  updateCrd(crd);
539  }
540 
541  ValueRange forwardImpl(OpBuilder &b, Location l) override;
542 
543  Value offset, stride, size;
544  std::unique_ptr<SparseIterator> wrap;
545 };
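// Worked example (illustrative): with offset = 2, stride = 3, size = 4 the
// filter exposes crd = 0..3 while the wrapped iterator runs over
// wrapCrd = 2, 5, 8, 11; any wrapped coordinate that is not of the form
// 2 + 3 * crd, or whose crd would be >= 4, is rejected by
// genCrdNotLegitPredicate() and skipped.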
546 
547 class NonEmptySubSectIterator : public SparseIterator {
548 public:
549  using TraverseBuilder = llvm::function_ref<scf::ValueVector(
550       OpBuilder &, Location, const SparseIterator *, ValueRange)>;
551 
552  NonEmptySubSectIterator(OpBuilder &b, Location l,
553  const SparseIterator *parent,
554  std::unique_ptr<SparseIterator> &&delegate,
555  Value subSectSz)
556  : SparseIterator(IterKind::kNonEmptySubSect, 3, subSectMeta, *delegate),
557  parent(parent), delegate(std::move(delegate)),
558  tupleSz(this->delegate->serialize().size()), subSectSz(subSectSz) {
559  auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
560  if (p == nullptr) {
561  // Extract subsections along the root level.
562  maxTupleCnt = C_IDX(1);
563  } else if (p->lvl == lvl) {
564  // Extract subsections along the same level.
565  maxTupleCnt = p->maxTupleCnt;
566  assert(false && "Not implemented.");
567  } else {
568  // Extract subsections along the previous level.
569  assert(p->lvl + 1 == lvl);
570  maxTupleCnt = MULI(p->maxTupleCnt, p->subSectSz);
571  }
572  // We don't need an extra buffer to find subsections on random-accessible
573  // levels.
574  if (randomAccessible())
575  return;
576  subSectPosBuf = allocSubSectPosBuf(b, l);
577  }
578 
579  // For LLVM-style RTTI.
580  static bool classof(const SparseIterator *from) {
581  return from->kind == IterKind::kNonEmptySubSect;
582  }
583 
584  std::string getDebugInterfacePrefix() const override {
585  return std::string("ne_sub<") + delegate->getDebugInterfacePrefix() + ">";
586  }
587  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
588  // minCrd, absolute offset, notEnd
589  return {b.getIndexType(), b.getIndexType(), b.getI1Type()};
590  }
591 
592  // The sliced pointer buffer is organized as:
593  // [[itVal0, itVal1, ..., pNx0],
594  // [itVal0, itVal1, ..., pNx0],
595  // ...]
596  Value allocSubSectPosBuf(OpBuilder &b, Location l) {
597  return b.create<memref::AllocaOp>(
598  l,
599  MemRefType::get({ShapedType::kDynamic, tupleSz + 1}, b.getIndexType()),
600  maxTupleCnt);
601  }
602 
603  void storeNxLvlStart(OpBuilder &b, Location l, Value tupleId,
604  Value start) const {
605  b.create<memref::StoreOp>(l, start, subSectPosBuf,
606  ValueRange{tupleId, C_IDX(tupleSz)});
607  }
608 
609  Value loadNxLvlStart(OpBuilder &b, Location l, Value tupleId) const {
610  return b.create<memref::LoadOp>(l, subSectPosBuf,
611  ValueRange{tupleId, C_IDX(tupleSz)});
612  }
613 
614  void storeCursorVals(OpBuilder &b, Location l, Value tupleId,
615  ValueRange itVals) const {
616  assert(itVals.size() == tupleSz);
617  for (unsigned i = 0; i < tupleSz; i++) {
618  b.create<memref::StoreOp>(l, itVals[i], subSectPosBuf,
619  ValueRange{tupleId, C_IDX(i)});
620  }
621  }
622 
623  SmallVector<Value> loadCursorVals(OpBuilder &b, Location l,
624  Value tupleId) const {
625  SmallVector<Value> ret;
626  for (unsigned i = 0; i < tupleSz; i++) {
627  Value v = b.create<memref::LoadOp>(l, subSectPosBuf,
628  ValueRange{tupleId, C_IDX(i)});
629  ret.push_back(v);
630  }
631  return ret;
632  }
633 
634  bool isSubSectRoot() const {
635  return !parent || !llvm::isa<NonEmptySubSectIterator>(parent);
636  }
637 
638  // Generates code that inflates the current subsection tree up to the current
639  // level such that every leaf node is visited.
640  ValueRange inflateSubSectTree(OpBuilder &b, Location l, ValueRange reduc,
641  TraverseBuilder builder) const;
642 
643  bool isBatchIterator() const override { return delegate->isBatchIterator(); }
644  bool randomAccessible() const override {
645  return delegate->randomAccessible();
646  };
647  bool iteratableByFor() const override { return randomAccessible(); };
648  Value upperBound(OpBuilder &b, Location l) const override {
649  auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
650  Value parentUB =
651  p && p->lvl == lvl ? p->upperBound(b, l) : delegate->upperBound(b, l);
652  return ADDI(SUBI(parentUB, subSectSz), C_IDX(1));
653  };
654 
655  void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override;
656 
657  void locateImpl(OpBuilder &b, Location l, Value crd) override {
658  Value absOff = crd;
659 
660  if (isSubSectRoot())
661  delegate->locate(b, l, absOff);
662  else
663  assert(parent->lvl + 1 == lvl);
664 
665  seek(ValueRange{absOff, absOff, C_TRUE});
666  updateCrd(crd);
667  }
668 
669  Value toSubSectCrd(OpBuilder &b, Location l, Value wrapCrd) const {
670  return SUBI(wrapCrd, getAbsOff());
671  }
672 
673  Value genNotEndImpl(OpBuilder &b, Location l) override {
674  return getNotEnd();
675  };
676 
677  Value derefImpl(OpBuilder &b, Location l) override {
678  // Use the relative offset to coiterate.
679  Value crd;
680  auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
681  if (p && p->lvl == lvl)
682  crd = SUBI(getAbsOff(), p->getAbsOff());
683  crd = getAbsOff();
684 
685  updateCrd(crd);
686  return crd;
687  };
688 
689  ValueRange forwardImpl(OpBuilder &b, Location l) override;
690 
691  Value getMinCrd() const { return subSectMeta[0]; }
692  Value getAbsOff() const { return subSectMeta[1]; }
693  Value getNotEnd() const { return subSectMeta[2]; }
694 
695  const SparseIterator *parent;
696  std::unique_ptr<SparseIterator> delegate;
697 
698  // Number of values required to serialize the wrapped iterator.
699  const unsigned tupleSz;
700  // Max number of tuples, and the actual number of tuples.
701  Value maxTupleCnt, tupleCnt;
702  // The memory used to cache the tuple serialized from the wrapped iterator.
703  Value subSectPosBuf;
704 
705  const Value subSectSz;
706 
707  // minCrd, absolute offset, notEnd
708  SmallVector<Value, 3> subSectMeta{nullptr, nullptr, nullptr};
709 };
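// Illustrative note (not part of the original file): this iterator enumerates
// the offsets of all size-`subSectSz` slices that contain at least one stored
// element, which is what slice-driven lowering of affine index expressions
// such as d0 + d1 (e.g., convolutions) needs. E.g., for stored coordinates
// {3, 7} and subSectSz = 4, the first non-empty slice starts at offset 0
// (minCrd = 3 < 4, see offsetFromMinCrd above).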
710 
711 class SubSectIterator;
712 
713 // A wrapper that helps generate code to traverse a subsection, used
714 // by both `NonEmptySubSectIterator` and `SubSectIterator`.
715 struct SubSectIterHelper {
716  explicit SubSectIterHelper(const SubSectIterator &iter);
717  explicit SubSectIterHelper(const NonEmptySubSectIterator &subSect);
718 
719  // Delegate methods.
720  void deserializeFromTupleId(OpBuilder &b, Location l, Value tupleId);
721  void locate(OpBuilder &b, Location l, Value crd);
722  Value genNotEnd(OpBuilder &b, Location l);
723  Value deref(OpBuilder &b, Location l);
724  ValueRange forward(OpBuilder &b, Location l);
725 
726  const NonEmptySubSectIterator &subSect;
727   SparseIterator &wrap;
728 };
729 
730 class SubSectIterator : public SparseIterator {
731 public:
732  SubSectIterator(const NonEmptySubSectIterator &subSect,
733  const SparseIterator &parent,
734  std::unique_ptr<SparseIterator> &&wrap)
735       : SparseIterator(IterKind::kSubSect, *wrap,
736                        /*extraCursorCnt=*/wrap->randomAccessible() ? 0 : 1),
737  subSect(subSect), wrap(std::move(wrap)), parent(parent), helper(*this) {
738  assert(subSect.tid == tid && subSect.lvl == lvl);
739  assert(parent.kind != IterKind::kSubSect || parent.lvl + 1 == lvl);
740  };
741 
742  // For LLVM-style RTTI.
743  static bool classof(const SparseIterator *from) {
744  return from->kind == IterKind::kSubSect;
745  }
746 
747  std::string getDebugInterfacePrefix() const override {
748  return std::string("subsect<") + wrap->getDebugInterfacePrefix() + ">";
749  }
750  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
751  SmallVector<Type> ret = wrap->getCursorValTypes(b);
752  if (!randomAccessible())
753  ret.push_back(b.getIndexType()); // The extra counter.
754  return ret;
755  }
756 
757  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
758  bool randomAccessible() const override { return wrap->randomAccessible(); };
759  bool iteratableByFor() const override { return randomAccessible(); };
760  Value upperBound(OpBuilder &b, Location l) const override {
761  return subSect.subSectSz;
762  }
763 
764  ValueRange getCurPosition() const override { return wrap->getCurPosition(); };
765 
766  Value getNxLvlTupleId(OpBuilder &b, Location l) const {
767  if (randomAccessible()) {
768  return ADDI(getCrd(), nxLvlTupleStart);
769  };
770  return ADDI(getCursor().back(), nxLvlTupleStart);
771  }
772 
773  void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override {
774  if (randomAccessible()) {
775  if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
776  assert(p->lvl + 1 == lvl);
777  wrap->genInit(b, l, p);
778  // Linearize the dense subsection index.
779  nxLvlTupleStart = MULI(subSect.subSectSz, p->getNxLvlTupleId(b, l));
780  } else {
781  assert(subSect.lvl == lvl && subSect.isSubSectRoot());
782  wrap->deserialize(subSect.delegate->serialize());
783  nxLvlTupleStart = C_IDX(0);
784  }
785  return;
786  }
787  assert(!randomAccessible());
788  assert(getCursor().size() == wrap->getCursor().size() + 1);
789  // Extra counter that counts the number of actually visited coordinates in
790  // the sparse subsection.
791  getMutCursorVals().back() = C_IDX(0);
792  Value tupleId;
793  if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
794  assert(p->lvl + 1 == lvl);
795  tupleId = p->getNxLvlTupleId(b, l);
796  } else {
797  assert(subSect.lvl == lvl && subSect.isSubSectRoot());
798  tupleId = C_IDX(0);
799  }
800  nxLvlTupleStart = subSect.loadNxLvlStart(b, l, tupleId);
801  helper.deserializeFromTupleId(b, l, tupleId);
802  }
803 
804  void locateImpl(OpBuilder &b, Location l, Value crd) override {
805  helper.locate(b, l, crd);
806  updateCrd(crd);
807  }
808 
809  Value genNotEndImpl(OpBuilder &b, Location l) override {
810  return helper.genNotEnd(b, l);
811  }
812 
813  Value derefImpl(OpBuilder &b, Location l) override {
814  Value crd = helper.deref(b, l);
815  updateCrd(crd);
816  return crd;
817  };
818 
819  ValueRange forwardImpl(OpBuilder &b, Location l) override {
820  helper.forward(b, l);
821  assert(!randomAccessible());
822  assert(getCursor().size() == wrap->getCursor().size() + 1);
823  getMutCursorVals().back() = ADDI(getCursor().back(), C_IDX(1));
824  return getCursor();
825  };
826 
827  Value nxLvlTupleStart;
828 
829  const NonEmptySubSectIterator &subSect;
830  std::unique_ptr<SparseIterator> wrap;
831  const SparseIterator &parent;
832 
833  SubSectIterHelper helper;
834 };
835 
836 } // namespace
837 
838 //===----------------------------------------------------------------------===//
839 // SparseIterator derived classes implementation.
840 //===----------------------------------------------------------------------===//
841 
842 void SparseIterator::genInit(OpBuilder &b, Location l,
843                              const SparseIterator *p) {
844   if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
845     std::string prefix = getDebugInterfacePrefix();
846  Operation *begin = b.create(l, b.getStringAttr(prefix + ".begin"), {},
847  getCursorValTypes(b));
848  seek(begin->getResults());
849  return;
850  }
851  // Inherit batch coordinates from the parent.
852  if (p)
853  inherentBatch(*p);
854  // TODO: support lowering to function call.
855  return genInitImpl(b, l, p);
856 }
857 
857 
858 Value SparseIterator::genNotEnd(OpBuilder &b, Location l) {
859   if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
860     std::string prefix = getDebugInterfacePrefix();
861  Operation *notEnd = b.create(l, b.getStringAttr(prefix + ".not_end"),
862  getCursor(), b.getI1Type());
863  return notEnd->getResult(0);
864  }
865  // TODO: support lowering to function call.
866  return genNotEndImpl(b, l);
867 }
868 
869 void SparseIterator::locate(OpBuilder &b, Location l, Value crd) {
870   if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
871     std::string prefix = getDebugInterfacePrefix();
872  SmallVector<Value> args = getCursor();
873  args.push_back(crd);
874  Operation *locate = b.create(l, b.getStringAttr(prefix + ".locate"), args,
875  getCursorValTypes(b));
876  seek(locate->getResults());
877  updateCrd(crd);
878  return;
879  }
880  return locateImpl(b, l, crd);
881 }
882 
883 Value SparseIterator::deref(OpBuilder &b, Location l) {
884   if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
885     std::string prefix = getDebugInterfacePrefix();
886  SmallVector<Value> args = getCursor();
887  Operation *deref = b.create(l, b.getStringAttr(prefix + ".deref"),
888  getCursor(), b.getIndexType());
889  updateCrd(deref->getResult(0));
890  return getCrd();
891  }
892  return derefImpl(b, l);
893 }
894 
895 ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
896   assert(!randomAccessible());
897   if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
898     std::string prefix = getDebugInterfacePrefix();
899     Operation *next = b.create(l, b.getStringAttr(prefix + ".next"),
900                                getCursor(), getCursorValTypes(b));
901     seek(next->getResults());
902  return getCursor();
903  }
904  return forwardImpl(b, l);
905 }
906 
907 ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) {
908   auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), cond, true);
909  // Generate the else branch first; otherwise the iterator values will be
910  // updated by `forward()`.
911  b.setInsertionPointToStart(ifOp.elseBlock());
912  YIELD(getCursor());
913 
914  b.setInsertionPointToStart(ifOp.thenBlock());
915  YIELD(forward(b, l));
916 
917  b.setInsertionPointAfter(ifOp);
918  seek(ifOp.getResults());
919  return getCursor();
920 }
921 
922 Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
923  auto whileOp = b.create<scf::WhileOp>(
924  l, pos.getType(), pos,
925  /*beforeBuilder=*/
926  [this, pos](OpBuilder &b, Location l, ValueRange ivs) {
927  Value inBound = CMPI(ult, ivs.front(), posHi);
928  auto ifInBound = b.create<scf::IfOp>(l, b.getI1Type(), inBound, true);
929  {
930  OpBuilder::InsertionGuard guard(b);
931  // If in bound, load the next coordinates and check duplication.
932  b.setInsertionPointToStart(ifInBound.thenBlock());
933  Value headCrd = stl.peekCrdAt(b, l, getBatchCrds(), pos);
934  Value tailCrd = stl.peekCrdAt(b, l, getBatchCrds(), ivs.front());
935  Value isDup = CMPI(eq, headCrd, tailCrd);
936  YIELD(isDup);
937  // Else, the position is out of bound, yield false.
938  b.setInsertionPointToStart(ifInBound.elseBlock());
939  YIELD(constantI1(b, l, false));
940  }
941  b.create<scf::ConditionOp>(l, ifInBound.getResults()[0], ivs);
942  },
943  /*afterBuilder=*/
944  [](OpBuilder &b, Location l, ValueRange ivs) {
945  Value nxPos = ADDI(ivs[0], C_IDX(1));
946  YIELD(nxPos);
947  });
948  // Return the segment high.
949  return whileOp.getResult(0);
950 }
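// Worked example (illustrative): for a non-unique level with coordinates
// crd = [1, 1, 1, 2] and pos = 0, the while-loop above keeps advancing while
// the coordinate equals crd[pos] and returns the segment high 3, the first
// position whose coordinate differs (or posHi if all remaining are duplicates).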
951 
952 Value FilterIterator::genCrdNotLegitPredicate(OpBuilder &b, Location l,
953  Value wrapCrd) {
954  Value crd = fromWrapCrd(b, l, wrapCrd);
955  // Test whether the coordinate is on stride.
956  Value notlegit = CMPI(ne, toWrapCrd(b, l, crd), wrapCrd);
957  // Test wrapCrd < offset
958  notlegit = ORI(CMPI(ult, wrapCrd, offset), notlegit);
959  // Test crd >= length
960  notlegit = ORI(CMPI(uge, crd, size), notlegit);
961  return notlegit;
962 }
963 
964 Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
965  auto r = genWhenInBound(
966  b, l, *wrap, C_FALSE,
967  [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
968  Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
969  return {notLegit};
970  });
971 
972  assert(r.size() == 1);
973  return r.front();
974 }
975 
976 Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) {
977  assert(!wrap->randomAccessible());
978  auto r = genWhenInBound(
979  b, l, *wrap, C_FALSE,
980  [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
981  Value crd = fromWrapCrd(b, l, wrapCrd);
982  // crd < size
983  return {CMPI(ult, crd, size)};
984  });
985  assert(r.size() == 1);
986  return r.front();
987 }
988 
989 ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
990  assert(!randomAccessible());
991  // Generates
992  //
993  // bool isFirst = true;
994  // while !it.end() && (!legit(*it) || isFirst)
995  // wrap ++;
996  // isFirst = false;
997  //
998  // We do not hoist the first `wrap++` outside the loop but use an `isFirst`
999  // flag here because `wrap++` might have a complex implementation (e.g., to
1000  // forward a subsection).
1001  Value isFirst = constantI1(b, l, true);
1002 
1003  SmallVector<Value> whileArgs(getCursor().begin(), getCursor().end());
1004  whileArgs.push_back(isFirst);
1005  auto whileOp = b.create<scf::WhileOp>(
1006  l, ValueRange(whileArgs).getTypes(), whileArgs,
1007  /*beforeBuilder=*/
1008  [this](OpBuilder &b, Location l, ValueRange ivs) {
1009  ValueRange isFirst = linkNewScope(ivs);
1010  assert(isFirst.size() == 1);
1011  scf::ValueVector cont =
1012  genWhenInBound(b, l, *wrap, C_FALSE,
1013  [this, isFirst](OpBuilder &b, Location l,
1014  Value wrapCrd) -> scf::ValueVector {
1015  // crd < size && !legit();
1016  Value notLegit =
1017  genCrdNotLegitPredicate(b, l, wrapCrd);
1018  Value crd = fromWrapCrd(b, l, wrapCrd);
1019  Value ret = ANDI(CMPI(ult, crd, size), notLegit);
1020  ret = ORI(ret, isFirst.front());
1021  return {ret};
1022  });
1023  b.create<scf::ConditionOp>(l, cont.front(), ivs);
1024  },
1025  /*afterBuilder=*/
1026  [this](OpBuilder &b, Location l, ValueRange ivs) {
1027  linkNewScope(ivs);
1028  wrap->forward(b, l);
1029  SmallVector<Value> yieldVals(getCursor().begin(), getCursor().end());
1030  yieldVals.push_back(constantI1(b, l, false));
1031  YIELD(yieldVals);
1032  });
1033 
1034  b.setInsertionPointAfter(whileOp);
1035  linkNewScope(whileOp.getResults());
1036  return getCursor();
1037 }
1038 
1039 SubSectIterHelper::SubSectIterHelper(const NonEmptySubSectIterator &subSect)
1040  : subSect(subSect), wrap(*subSect.delegate) {}
1041 
1042 SubSectIterHelper::SubSectIterHelper(const SubSectIterator &iter)
1043  : subSect(iter.subSect), wrap(*iter.wrap) {}
1044 
1045 void SubSectIterHelper::deserializeFromTupleId(OpBuilder &b, Location l,
1046  Value tupleId) {
1047  assert(!subSect.randomAccessible());
1048  wrap.deserialize(subSect.loadCursorVals(b, l, tupleId));
1049 }
1050 
1051 void SubSectIterHelper::locate(OpBuilder &b, Location l, Value crd) {
1052  Value absCrd = ADDI(crd, subSect.getAbsOff());
1053  wrap.locate(b, l, absCrd);
1054 }
1055 
1056 Value SubSectIterHelper::genNotEnd(OpBuilder &b, Location l) {
1057  assert(!wrap.randomAccessible());
1058  auto r = genWhenInBound(
1059  b, l, wrap, C_FALSE,
1060  [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1061  Value crd = SUBI(wrapCrd, subSect.getAbsOff());
1062  // crd < size
1063  return {CMPI(ult, crd, subSect.subSectSz)};
1064  });
1065  assert(r.size() == 1);
1066  return r.front();
1067 }
1068 
1069 Value SubSectIterHelper::deref(OpBuilder &b, Location l) {
1070  Value wrapCrd = wrap.deref(b, l);
1071  Value crd = subSect.toSubSectCrd(b, l, wrapCrd);
1072  return crd;
1073 }
1074 
1075 ValueRange SubSectIterHelper::forward(OpBuilder &b, Location l) {
1076  return wrap.forward(b, l);
1077 }
1078 
1079 ValueRange NonEmptySubSectIterator::inflateSubSectTree(
1080  OpBuilder &b, Location l, ValueRange reduc, TraverseBuilder builder) const {
1081  // Set up the helper to help traverse a sparse subsection.
1082  SubSectIterHelper helper(*this);
1083  if (!randomAccessible()) {
1084  // The subsection tree has been expanded up to this level and cached;
1085  // traverse all the leaves and expand to the next level.
1086  SmallVector<Value> iterArgs;
1087  iterArgs.push_back(C_IDX(0));
1088  iterArgs.append(reduc.begin(), reduc.end());
1089  auto forEachLeaf = b.create<scf::ForOp>(
1090  l, /*lb=*/C_IDX(0), /*ub=*/tupleCnt, /*step=*/C_IDX(1), iterArgs,
1091  [&helper, &builder](OpBuilder &b, Location l, Value tupleId,
1092  ValueRange iterArgs) {
1093  // Deserialize the iterator at the cached position (tupleId).
1094  helper.deserializeFromTupleId(b, l, tupleId);
1095 
1096  Value cnt = iterArgs.front();
1097  // Record the number of leaf nodes included in the subsection.
1098  // The number indicates the starting tupleId for the next level that
1099  // corresponds to the current node.
1100  helper.subSect.storeNxLvlStart(b, l, tupleId, cnt);
1101 
1102  SmallVector<Value> whileArgs(helper.wrap.getCursor());
1103  whileArgs.append(iterArgs.begin(), iterArgs.end());
1104 
1105  auto whileOp = b.create<scf::WhileOp>(
1106  l, ValueRange(whileArgs).getTypes(), whileArgs,
1107  /*beforeBuilder=*/
1108  [&helper](OpBuilder &b, Location l, ValueRange ivs) {
1109  helper.wrap.linkNewScope(ivs);
1110  b.create<scf::ConditionOp>(l, helper.genNotEnd(b, l), ivs);
1111  },
1112  /*afterBuilder=*/
1113  [&helper, &builder](OpBuilder &b, Location l, ValueRange ivs) {
1114  ValueRange remIter = helper.wrap.linkNewScope(ivs);
1115  Value cnt = remIter.front();
1116  ValueRange userIter = remIter.drop_front();
1117  scf::ValueVector userNx = builder(b, l, &helper.wrap, userIter);
1118 
1119  SmallVector<Value> nxIter = helper.forward(b, l);
1120  nxIter.push_back(ADDI(cnt, C_IDX(1)));
1121  nxIter.append(userNx.begin(), userNx.end());
1122  YIELD(nxIter);
1123  });
1124  ValueRange res = helper.wrap.linkNewScope(whileOp.getResults());
1125  YIELD(res);
1126  });
1127  return forEachLeaf.getResults().drop_front();
1128  }
1129 
1130  assert(randomAccessible());
1131  // Helper lambda that traverses the current dense subsection range.
1132  auto visitDenseSubSect = [&, this](OpBuilder &b, Location l,
1133  const SparseIterator *parent,
1134  ValueRange reduc) {
1135  assert(!parent || parent->lvl + 1 == lvl);
1136  delegate->genInit(b, l, parent);
1137  auto forOp = b.create<scf::ForOp>(
1138  l, /*lb=*/C_IDX(0), /*ub=*/subSectSz, /*step=*/C_IDX(1), reduc,
1139  [&](OpBuilder &b, Location l, Value crd, ValueRange iterArgs) {
1140  helper.locate(b, l, crd);
1141  scf::ValueVector nx = builder(b, l, &helper.wrap, iterArgs);
1142  YIELD(nx);
1143  });
1144  return forOp.getResults();
1145  };
1146 
1147  if (isSubSectRoot()) {
1148  return visitDenseSubSect(b, l, parent, reduc);
1149  }
1150  // Else, this is not the root; recurse until the root.
1151  auto *p = llvm::cast<NonEmptySubSectIterator>(parent);
1152  assert(p->lvl + 1 == lvl);
1153  return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
1154 }
1155 
1156 void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
1157  const SparseIterator *) {
1158  Value c0 = C_IDX(0);
1159  if (!isSubSectRoot()) {
1160  assert(parent->lvl + 1 == lvl);
1161  if (randomAccessible()) {
1162  // We cannot call wrap->genInit() here to initialize the wrapped
1163  // iterator, because the parent of the current iterator is still
1164  // unresolved.
1165  seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
1166  return;
1167  }
1168 
1169  auto *p = cast<NonEmptySubSectIterator>(parent);
1170  SmallVector<Value, 3> reduc = {
1171  C_IDX(-1), // minCrd (max signless integer)
1172  c0, // tupleId
1173  };
1174 
1175  // Expand the subsection tree from the parent level to the current level.
1176  ValueRange result = p->inflateSubSectTree(
1177  b, l, reduc,
1178  [this](OpBuilder &b, Location l, const SparseIterator *parent,
1179  ValueRange reduc) -> scf::ValueVector {
1180  assert(parent->lvl + 1 == lvl && reduc.size() == 2);
1181  Value minCrd = reduc.front();
1182  Value tupleId = reduc.back();
1183 
1184  // Initialize the subsection range.
1185  SubSectIterHelper helper(*this);
1186  helper.wrap.genInit(b, l, parent);
1187 
1188  // Update minCrd.
1189  minCrd = genWhenInBound(b, l, helper.wrap, minCrd,
1190  [minCrd](OpBuilder &b, Location l,
1191  Value crd) -> scf::ValueVector {
1192  Value min = MINUI(crd, minCrd);
1193  return {min};
1194  })
1195  .front();
1196 
1197  // Cache the sparse range.
1198  storeCursorVals(b, l, tupleId, helper.wrap.serialize());
1199  tupleId = ADDI(tupleId, C_IDX(1));
1200  return {minCrd, tupleId};
1201  });
1202  assert(result.size() == 2);
1203  tupleCnt = result.back();
1204 
1205  Value minCrd = result.front();
1206  Value absOff = offsetFromMinCrd(b, l, minCrd, subSectSz);
1207  Value notEnd = CMPI(ne, minCrd, C_IDX(-1));
1208  seek({minCrd, absOff, notEnd});
1209  return;
1210  }
1211 
1212  // This is the root level of the subsection, which means that it is resolved
1213  // to one node.
1214  assert(isSubSectRoot());
1215 
1216  // Initialize the position; the position marks the *lower bound* of the
1217  // subRange. The upper bound is determined by the size of the subsection.
1218  delegate->genInit(b, l, parent);
1219  if (randomAccessible()) {
1220  seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
1221  return;
1222  }
1223 
1224  // Only have one root node.
1225  tupleCnt = C_IDX(1);
1226  // Cache the sparse range.
1227  storeCursorVals(b, l, c0, delegate->serialize());
1228  SmallVector<Value> elseRet{c0, c0, /*notEnd=*/C_FALSE};
1229  auto meta = genWhenInBound(
1230  b, l, *delegate, elseRet,
1231  [this](OpBuilder &b, Location l, Value crd) -> scf::ValueVector {
1232  Value offset = offsetFromMinCrd(b, l, crd, subSectSz);
1233  return {crd, offset, C_TRUE};
1234  });
1235 
1236  seek(meta);
1237 }
1238 
1239 ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
1240  assert(!randomAccessible());
1241  Value c0 = C_IDX(0), c1 = C_IDX(1);
1242  // Forward to the next non-empty slice by generating
1243  //
1244  // if (minCrd > offset) {
1245  // offset += 1
1246  // } else {
1247  // minCrd = nextMinInSlice();
1248  // offset = minCrd - size + 1;
1249  // }
1250  //
1251  // if (offset + size > parents.size)
1252  // isNonEmpty = false;
1253  Value fastPathP = CMPI(ugt, getMinCrd(), getAbsOff());
1254  auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), fastPathP, true);
1255  {
1256  OpBuilder::InsertionGuard guard(b);
1257  // Take the fast path
1258  // if (minCrd > offset)
1259  // offset += 1
1260  b.setInsertionPointToStart(&ifOp.getThenRegion().front());
1261  Value nxOffset = ADDI(getAbsOff(), c1);
1262  YIELD((ValueRange{getMinCrd(), nxOffset, getNotEnd()}));
1263 
1264  // else /*minCrd == offset*/ {
1265  // for (i = 0; i < tupleCnt; i++) {
1266  // wrap->deserialize(pos[i]);
1267  // minCrd=min(minCrd, *wrap);
1268  // }
1269  // offset = minCrd - size + 1;
1270  // }
1271  b.setInsertionPointToStart(&ifOp.getElseRegion().front());
1272  SmallVector<Value, 2> loopArgs{C_IDX(-1), // nextMinCrd
1273  C_FALSE}; // isNotEnd
1274  auto loopNest = scf::buildLoopNest(
1275  b, l, c0, tupleCnt, c1, loopArgs,
1276  [this](OpBuilder &b, Location l, ValueRange ivs,
1277  ValueRange iterArgs) -> scf::ValueVector {
1278  Value tupleId = ivs.front();
1279  SubSectIterHelper helper(*this);
1280  helper.deserializeFromTupleId(b, l, tupleId);
1281 
1282  return genWhenInBound(
1283  b, l, *delegate, /*elseRet=*/iterArgs,
1284  [this, iterArgs, tupleId](OpBuilder &b, Location l,
1285  Value crd) -> scf::ValueVector {
1286  // if coord == minCrd
1287  // wrap->forward();
1288  Value isMin = CMPI(eq, crd, getMinCrd());
1289  delegate->forwardIf(b, l, isMin);
1290  // Update the forwarded iterator values if needed.
1291  auto ifIsMin = b.create<scf::IfOp>(l, isMin, false);
1292  b.setInsertionPointToStart(&ifIsMin.getThenRegion().front());
1293  storeCursorVals(b, l, tupleId, delegate->serialize());
1294  b.setInsertionPointAfter(ifIsMin);
1295  // if (!wrap.end())
1296  // yield(min(nxMinCrd, *wrap), true)
1297  Value nxMin = iterArgs[0];
1298  return genWhenInBound(b, l, *delegate, /*elseRet=*/iterArgs,
1299  [nxMin](OpBuilder &b, Location l,
1300  Value crd) -> scf::ValueVector {
1301  Value nx = b.create<arith::MinUIOp>(
1302  l, crd, nxMin);
1303  return {nx, C_TRUE};
1304  });
1305  });
1306  });
1307 
1308  scf::ForOp forOp = loopNest.loops.front();
1309  b.setInsertionPointAfter(forOp);
1310 
1311  Value nxMinCrd = forOp.getResult(0);
1312  Value nxNotEnd = forOp.getResult(1);
1313  Value nxAbsOff = offsetFromMinCrd(b, l, nxMinCrd, subSectSz);
1314  YIELD((ValueRange{nxMinCrd, nxAbsOff, nxNotEnd}));
1315  }
1316 
1317  Value nxMinCrd = ifOp.getResult(0);
1318  Value nxAbsOff = ifOp.getResult(1);
1319  Value nxNotEnd = ifOp.getResult(2);
1320 
1321  // We should at least forward the offset by one.
1322  Value minAbsOff = ADDI(getAbsOff(), c1);
1323  nxAbsOff = b.create<arith::MaxUIOp>(l, minAbsOff, nxAbsOff);
1324 
1325  seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
1326  // The coordinate should not exceed the space upper bound.
1327  Value crd = deref(b, l);
1328  nxNotEnd = ANDI(nxNotEnd, CMPI(ult, crd, upperBound(b, l)));
1329 
1330  seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
1331  return getCursor();
1332 }
1333 
1334 //===----------------------------------------------------------------------===//
1335 // SparseIterator factory functions.
1336 //===----------------------------------------------------------------------===//
1337 
1338 std::unique_ptr<SparseTensorLevel>
1339 sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
1340                                      unsigned tid, Level lvl) {
1341  auto stt = getSparseTensorType(t);
1342 
1343  LevelType lt = stt.getLvlType(lvl);
1344  Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
1345  : b.create<tensor::DimOp>(l, t, lvl).getResult();
1346 
1347  switch (lt.getLvlFmt()) {
1348  case LevelFormat::Dense:
1349  return std::make_unique<DenseLevel>(tid, lvl, sz);
1350  case LevelFormat::Batch:
1351  return std::make_unique<BatchLevel>(tid, lvl, sz);
1352  case LevelFormat::Compressed: {
1353  Value pos = b.create<ToPositionsOp>(l, t, lvl);
1354  Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
1355  return std::make_unique<CompressedLevel>(tid, lvl, lt, sz, pos, crd);
1356  }
1357  case LevelFormat::LooseCompressed: {
1358  Value pos = b.create<ToPositionsOp>(l, t, lvl);
1359  Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
1360  return std::make_unique<LooseCompressedLevel>(tid, lvl, lt, sz, pos, crd);
1361  }
1362  case LevelFormat::Singleton: {
1363  Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
1364  return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
1365  }
1366  case LevelFormat::NOutOfM: {
1367  Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
1368  return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
1369  }
1370  case LevelFormat::Undef:
1371  llvm_unreachable("undefined level format");
1372  }
1373  llvm_unreachable("unrecognizable level format");
1374 }
1375 
1376 std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
1377 sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
1378  SparseEmitStrategy strategy) {
1379  auto stl = std::make_unique<BatchLevel>(tid, lvl, sz);
1380  auto it = std::make_unique<TrivialIterator>(*stl);
1381  it->setSparseEmitStrategy(strategy);
1382  return std::make_pair(std::move(stl), std::move(it));
1383 }
1384 
1385 std::unique_ptr<SparseIterator>
1386 sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
1387                                   SparseEmitStrategy strategy) {
1388  std::unique_ptr<SparseIterator> ret;
1389  if (!isUniqueLT(stl.getLT())) {
1390  // We always deduplicate the non-unique level, but we should optimize it
1391  // away if possible.
1392  ret = std::make_unique<DedupIterator>(stl);
1393  } else {
1394  ret = std::make_unique<TrivialIterator>(stl);
1395  }
1396  ret->setSparseEmitStrategy(strategy);
1397  return ret;
1398 }
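// Illustrative usage sketch (hypothetical client code, not part of this file;
// assumes an OpBuilder `b`, a Location `loc`, and a sparse tensor value `t`
// whose level 0 is compressed):
//
//   auto stl = makeSparseTensorLevel(b, loc, t, /*tid=*/0, /*lvl=*/0);
//   auto it = makeSimpleIterator(*stl, SparseEmitStrategy::kFunctional);
//   it->genInit(b, loc, /*parent=*/nullptr);
//   // ... the loop emitter then drives it->genNotEnd(), it->deref() and
//   // it->forward() to build the actual scf loop.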
1399 
1400 std::unique_ptr<SparseIterator>
1401 sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
1402  Value offset, Value stride, Value size,
1403  SparseEmitStrategy strategy) {
1404 
1405  auto ret =
1406  std::make_unique<FilterIterator>(std::move(sit), offset, stride, size);
1407  ret->setSparseEmitStrategy(strategy);
1408  return ret;
1409 }
1410 
1411 static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) {
1412   auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
1413  if (filter)
1414  return filter->wrap.get();
1415  return it;
1416 }
1417 
1418 std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
1419  OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
1420  std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
1421  SparseEmitStrategy strategy) {
1422 
1423  // Try to unwrap the NonEmptySubSectIterator from a filter parent.
1424  parent = tryUnwrapFilter(parent);
1425  std::unique_ptr<SparseIterator> it =
1426  std::make_unique<NonEmptySubSectIterator>(b, l, parent,
1427  std::move(delegate), size);
1428 
1429  if (stride != 1) {
1430  // TODO: We can safely skip bound checking on sparse levels, but for dense
1431  // iteration space, we need the bound to infer the dense loop range.
1432  it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1433  C_IDX(stride), /*size=*/loopBound);
1434  }
1435  it->setSparseEmitStrategy(strategy);
1436  return it;
1437 }
1438 
1439 std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator(
1440  OpBuilder &b, Location l, const SparseIterator &subSectIter,
1441  const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
1442  Value loopBound, unsigned stride, SparseEmitStrategy strategy) {
1443 
1444  // This must be a subsection iterator or a filtered subsection iterator.
1445  auto &subSect =
1446  llvm::cast<NonEmptySubSectIterator>(*tryUnwrapFilter(&subSectIter));
1447 
1448  std::unique_ptr<SparseIterator> it = std::make_unique<SubSectIterator>(
1449  subSect, *tryUnwrapFilter(&parent), std::move(wrap));
1450 
1451  if (stride != 1) {
1452  it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1453  C_IDX(stride), /*size=*/loopBound);
1454  }
1455  it->setSparseEmitStrategy(strategy);
1456  return it;
1457 }
1458 
1459 #undef CMPI
1460 #undef C_IDX
1461 #undef YIELD
1462 #undef ADDI
1463 #undef ANDI
1464 #undef SUBI
1465 #undef MULI
1466 #undef SELECT