MLIR  19.0.0git
SparseTensorIterator.cpp
1 //===- SparseTensorIterator.cpp -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "SparseTensorIterator.h"
10 #include "CodegenUtils.h"
11 
12 
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 
16 using namespace mlir;
17 using namespace mlir::sparse_tensor;
18 using ValuePair = std::pair<Value, Value>;
19 using ValueTuple = std::tuple<Value, Value, Value>;
20 
21 //===----------------------------------------------------------------------===//
22 // File local helper functions/macros.
23 //===----------------------------------------------------------------------===//
24 #define CMPI(p, lhs, rhs) \
25  (b.create<arith::CmpIOp>(l, arith::CmpIPredicate::p, (lhs), (rhs)) \
26  .getResult())
27 
28 #define C_FALSE (constantI1(b, l, false))
29 #define C_TRUE (constantI1(b, l, true))
30 #define C_IDX(v) (constantIndex(b, l, (v)))
31 #define YIELD(vs) (b.create<scf::YieldOp>(l, (vs)))
32 #define ADDI(lhs, rhs) (b.create<arith::AddIOp>(l, (lhs), (rhs)).getResult())
33 #define ORI(lhs, rhs) (b.create<arith::OrIOp>(l, (lhs), (rhs)).getResult())
34 #define ANDI(lhs, rhs) (b.create<arith::AndIOp>(l, (lhs), (rhs)).getResult())
35 #define SUBI(lhs, rhs) (b.create<arith::SubIOp>(l, (lhs), (rhs)).getResult())
36 #define MULI(lhs, rhs) (b.create<arith::MulIOp>(l, (lhs), (rhs)).getResult())
37 #define MINUI(lhs, rhs) (b.create<arith::MinUIOp>(l, (lhs), (rhs)).getResult())
38 #define REMUI(lhs, rhs) (b.create<arith::RemUIOp>(l, (lhs), (rhs)).getResult())
39 #define DIVUI(lhs, rhs) (b.create<arith::DivUIOp>(l, (lhs), (rhs)).getResult())
40 #define SELECT(c, lhs, rhs) \
41  (b.create<arith::SelectOp>(l, (c), (lhs), (rhs)).getResult())
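// For example, SELECT(CMPI(uge, a, b), x, y) expands to an arith.cmpi followed
// by an arith.select, both created at location `l` with builder `b`.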
42 
43 //===----------------------------------------------------------------------===//
44 // SparseTensorLevel derived classes.
45 //===----------------------------------------------------------------------===//
46 
47 namespace {
48 
49 template <bool hasPosBuffer>
50 class SparseLevel : public SparseTensorLevel {
51  // It is either an array of size 2 or size 1 depending on whether the sparse
52  // level requires a position array.
53  using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
54  std::array<Value, 1>>;
55 
56 public:
57  SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
58  BufferT buffers)
59  : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}
60 
61  ValueRange getLvlBuffers() const override { return buffers; }
62 
63  Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
64  Value iv) const override {
65  SmallVector<Value> memCrd(batchPrefix);
66  memCrd.push_back(iv);
67  return genIndexLoad(b, l, getCrdBuf(), memCrd);
68  }
69 
70 protected:
71  template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
72  Value getPosBuf() const {
73  return buffers[0];
74  }
75 
76  Value getCrdBuf() const {
77  if constexpr (hasPosBuffer)
78  return buffers[1];
79  else
80  return buffers[0];
81  }
82 
83  const BufferT buffers;
84 };
85 
86 class DenseLevel : public SparseTensorLevel {
87 public:
88  DenseLevel(unsigned tid, Level lvl, Value lvlSize)
89  : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize) {}
90 
91  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
92  llvm_unreachable("locate random-accessible level instead");
93  }
94 
95  ValueRange getLvlBuffers() const override { return {}; }
96 
97  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
98  ValueRange parentPos, Value inPadZone) const override {
99  assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
100  assert(!inPadZone && "Not implemented");
101  Value p = parentPos.front();
102  Value posLo = MULI(p, lvlSize);
103  return {posLo, lvlSize};
104  }
105 };
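// Example (illustrative): for a dense level of size 4 and parent position p,
// peekRangeAt above returns {p * 4, 4}, i.e., the segment starts at
// p * lvlSize and the level size is returned alongside it.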
106 
107 class BatchLevel : public SparseTensorLevel {
108 public:
109  BatchLevel(unsigned tid, Level lvl, Value lvlSize)
110  : SparseTensorLevel(tid, lvl, LevelFormat::Batch, lvlSize) {}
111 
112  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
113  llvm_unreachable("locate random-accessible level instead");
114  }
115 
116  ValueRange getLvlBuffers() const override { return {}; }
117 
118  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
119  ValueRange parentPos, Value inPadZone) const override {
120  assert(!inPadZone && "Not implemented");
121  assert(parentPos.size() == 1 && "Batch level can not be non-unique.");
122  // No need to linearize the position for non-annotated tensors.
123  return {C_IDX(0), lvlSize};
124  }
125 };
126 
127 class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
128 public:
129  CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
130  Value posBuffer, Value crdBuffer)
131  : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
132 
133  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
134  ValueRange parentPos, Value inPadZone) const override {
135 
136  assert(parentPos.size() == 1 &&
137  "compressed level must be the first non-unique level.");
138 
139  auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
140  Value p = parentPos.front();
141  SmallVector<Value> memCrd(batchPrefix);
142  memCrd.push_back(p);
143  Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
144  memCrd.back() = ADDI(p, C_IDX(1));
145  Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
146  return {pLo, pHi};
147  };
148 
149  if (inPadZone == nullptr)
150  return loadRange();
151 
153  scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
154  // True branch, returns a "fake" empty range [0, 0) if parent
155  // iterator is in pad zone.
156  b.setInsertionPointToStart(posRangeIf.thenBlock());
157 
158  SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
159  b.create<scf::YieldOp>(l, emptyRange);
160 
161  // False branch, returns the actual range.
162  b.setInsertionPointToStart(posRangeIf.elseBlock());
163  auto [pLo, pHi] = loadRange();
164  SmallVector<Value, 2> loadedRange{pLo, pHi};
165  b.create<scf::YieldOp>(l, loadedRange);
166 
167  b.setInsertionPointAfter(posRangeIf);
168  ValueRange posRange = posRangeIf.getResults();
169  return {posRange.front(), posRange.back()};
170  }
171 };
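// Example (illustrative): with positions pos = [0, 2, 5] and parent position
// p = 1, loadRange() above yields [pos[1], pos[2]) = [2, 5), i.e., the
// coordinates of this segment live in crd[2..5). When the parent is a padded
// iterator and inPadZone is true, the scf.if yields the empty range [0, 0)
// instead.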
172 
173 class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
174 public:
175  LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
176  Value posBuffer, Value crdBuffer)
177  : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
178 
179  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
180  ValueRange parentPos, Value inPadZone) const override {
181  assert(parentPos.size() == 1 &&
182  "loose-compressed level must be the first non-unique level.");
183  assert(!inPadZone && "Not implemented");
184  SmallVector<Value> memCrd(batchPrefix);
185  Value p = parentPos.front();
186  p = MULI(p, C_IDX(2));
187  memCrd.push_back(p);
188  Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
189  memCrd.back() = ADDI(p, C_IDX(1));
190  Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
191  return {pLo, pHi};
192  }
193 };
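// Example (illustrative): a loose-compressed level keeps a lo/hi pair per
// segment, so for parent position p the range is [pos[2*p], pos[2*p+1]),
// which permits gaps between consecutive segments.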
194 
195 class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
196 public:
197  SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
198  Value crdBuffer)
199  : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
200 
201  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
202  ValueRange parentPos, Value inPadZone) const override {
203  assert(parentPos.size() == 1 || parentPos.size() == 2);
204  assert(!inPadZone && "Not implemented");
205  Value p = parentPos.front();
206  Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
207 
208  if (segHi == nullptr)
209  return {p, ADDI(p, C_IDX(1))};
210  // Use the segHi as the loop upper bound.
211  return {p, segHi};
212  }
213 };
214 
215 class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
216 public:
217  NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
218  Value crdBuffer)
219  : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
220 
221  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
222  ValueRange parentPos, Value inPadZone) const override {
223  assert(parentPos.size() == 1 && isUnique() &&
224  "n:m level can not be non-unique.");
225  assert(!inPadZone && "Not implemented");
226  // Each n:m blk has exactly n specified elements.
227  auto n = getN(lt);
228  Value posLo = MULI(parentPos.front(), C_IDX(n));
229  return {posLo, ADDI(posLo, C_IDX(n))};
230  }
231 };
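// Example (illustrative): for a 2:4 level (n = 2), block i stores exactly two
// coordinates, so peekRangeAt maps parent position i to [2*i, 2*i + 2).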
232 
233 } // namespace
234 
235 //===----------------------------------------------------------------------===//
236 // File local helpers
237 //===----------------------------------------------------------------------===//
238 
239 static scf::ValueVector genWhenInBound(
240     OpBuilder &b, Location l, SparseIterator &it, ValueRange elseRet,
241     llvm::function_ref<scf::ValueVector(OpBuilder &, Location, Value)>
242         builder) {
243  TypeRange ifRetTypes = elseRet.getTypes();
244  auto ifOp = b.create<scf::IfOp>(l, ifRetTypes, it.genNotEnd(b, l), true);
245 
246  b.setInsertionPointToStart(ifOp.thenBlock());
247  Value crd = it.deref(b, l);
248  scf::ValueVector ret = builder(b, l, crd);
249  YIELD(ret);
250 
251  b.setInsertionPointToStart(ifOp.elseBlock());
252  YIELD(elseRet);
253 
254  b.setInsertionPointAfter(ifOp);
255  return ifOp.getResults();
256 }
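// Roughly, genWhenInBound emits:
//
//   %r:N = scf.if %it.not_end() {
//     %crd = <deref it>
//     scf.yield builder(%crd)
//   } else {
//     scf.yield %elseRet
//   }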
257 
258 /// Generates code to compute the *absolute* offset of the slice based on the
259 /// provided minimum coordinates in the slice.
260 /// E.g., when reducing d0 + d1 + d2, we need two slices to fully reduce the
261 /// expression, i.e., s1 = slice(T, d0), s2 = slice(s1, d1). The *absolute*
262 /// offset is the offset computed relative to the initial tensor T.
263 ///
264 /// When isNonEmpty == true, the computed offset is meaningless and should not
265 /// be used at runtime; currently, the method generates code that returns 0 in
266 /// that case.
267 ///
268 /// offset = minCrd >= size ? minCrd - size + 1 : 0;
269 static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
270                               Value size) {
271  Value geSize = CMPI(uge, minCrd, size);
272  // Compute minCrd - size + 1.
273  Value mms = SUBI(ADDI(minCrd, C_IDX(1)), size);
274  // This is the absolute offset related to the actual tensor.
275  return SELECT(geSize, mms, C_IDX(0));
276 }
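// Example (illustrative): with minCrd = 7 and size = 3, the offset is
// 7 - 3 + 1 = 5; with minCrd = 1 and size = 3, the offset stays 0.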
277 
278 //===----------------------------------------------------------------------===//
279 // SparseIterator derived classes.
280 //===----------------------------------------------------------------------===//
281 
282 namespace {
283 
284 // The iterator that traverses a concrete sparse tensor level. High-level
285 // abstract iterators wrap it to achieve more complex goals (such as collapsing
286 // several levels). It also owns the common storage that holds the mlir::Values
287 // for itself as well as for its wrappers.
288 class ConcreteIterator : public SparseIterator {
289 protected:
290  ConcreteIterator(const SparseTensorLevel &stl, IterKind kind,
291  unsigned cursorValCnt)
292  : SparseIterator(kind, stl.tid, stl.lvl, cursorValCnt, cursorValsStorage),
293  stl(stl), cursorValsStorage(cursorValCnt, nullptr) {
294  assert(getCursor().size() == cursorValCnt);
295  };
296 
297 public:
298  // For LLVM-style RTTI.
299  static bool classof(const SparseIterator *from) {
300  return from->kind == IterKind::kTrivial;
301  }
302 
303  bool isBatchIterator() const override {
304  return stl.getLT().isa<LevelFormat::Batch>();
305  }
306  bool randomAccessible() const override {
307  return stl.getLT().hasDenseSemantic();
308  };
309  bool iteratableByFor() const override { return kind != IterKind::kDedup; };
310  Value upperBound(OpBuilder &b, Location l) const override {
311  return stl.getSize();
312  };
313 
314 protected:
315  const SparseTensorLevel &stl;
316  // Owner of the storage; all wrappers built on top of a concrete iterator
317  // share the same storage so that the iterator values are always
318  // synchronized.
319  SmallVector<Value> cursorValsStorage;
320 };
321 
322 class TrivialIterator : public ConcreteIterator {
323 public:
324  TrivialIterator(const SparseTensorLevel &stl)
325  : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
326 
327  std::string getDebugInterfacePrefix() const override {
328  return std::string("trivial<") + stl.toString() + ">";
329  }
330  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
331  return {b.getIndexType()};
332  }
333 
334  SmallVector<Value> serialize() const override {
335  SmallVector<Value> ret;
336  ret.push_back(getItPos());
337  if (randomAccessible()) {
338  // Loop high is implicit (defined by `upperBound()`) for random-access
339  // iterator, but we need to memorize posLo for linearization.
340  ret.push_back(posLo);
341  } else {
342  ret.push_back(posHi);
343  }
344  return ret;
345  };
346 
347  void deserialize(ValueRange vs) override {
348  assert(vs.size() == 2);
349  seek(vs.front());
350  if (randomAccessible())
351  posLo = vs.back();
352  else
353  posHi = vs.back();
354  };
355 
356  void genInitImpl(OpBuilder &b, Location l,
357  const SparseIterator *parent) override;
358 
359  ValuePair genForCond(OpBuilder &b, Location l) override {
360  if (randomAccessible())
361  return {deref(b, l), upperBound(b, l)};
362  return std::make_pair(getItPos(), posHi);
363  }
364 
365  Value genNotEndImpl(OpBuilder &b, Location l) override {
366  // We use the first level's bound as the bound for the collapsed set of levels.
367  return CMPI(ult, getItPos(), posHi);
368  }
369 
370  Value derefImpl(OpBuilder &b, Location l) override {
371  if (randomAccessible()) {
372  updateCrd(SUBI(getItPos(), posLo));
373  } else {
374  updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getItPos()));
375  }
376  return getCrd();
377  };
378 
379  ValueRange forwardImpl(OpBuilder &b, Location l) override {
380  seek(ADDI(getItPos(), C_IDX(1)));
381  return getCursor();
382  }
383 
384  ValueRange forwardIf(OpBuilder &b, Location l, Value cond) override {
385  Value curPos = getCursor().front();
386  Value nxPos = forward(b, l).front();
387  seek(SELECT(cond, nxPos, curPos));
388  return getCursor();
389  }
390 
391  void locateImpl(OpBuilder &b, Location l, Value crd) override {
392  assert(randomAccessible());
393  // Seek to the linearized position.
394  seek(ADDI(crd, posLo));
395  updateCrd(crd);
396  if (isBatchIterator()) {
397  // If this is a batch iterator, also update the batch coordinate.
398  assert(batchCrds.size() > lvl);
399  batchCrds[lvl] = crd;
400  }
401  }
402 
403  Value getItPos() const { return getCursor().front(); }
404  Value posLo, posHi;
405 };
406 
407 class DedupIterator : public ConcreteIterator {
408 private:
409  Value genSegmentHigh(OpBuilder &b, Location l, Value pos);
410 
411 public:
412  DedupIterator(const SparseTensorLevel &stl)
413  : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) {
414  assert(!stl.isUnique());
415  }
416  // For LLVM-style RTTI.
417  static bool classof(const SparseIterator *from) {
418  return from->kind == IterKind::kDedup;
419  }
420 
421  std::string getDebugInterfacePrefix() const override {
422  return std::string("dedup<") + stl.toString() + ">";
423  }
424  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
425  return {b.getIndexType(), b.getIndexType()};
426  }
427 
428  void genInitImpl(OpBuilder &b, Location l,
429  const SparseIterator *parent) override {
430  Value c0 = C_IDX(0);
431  ValueRange pPos = c0;
432 
433  // If the parent iterator is a batch iterator, we also start from 0 (but
434  // on a different batch).
435  if (parent && !parent->isBatchIterator())
436  pPos = parent->getCurPosition();
437 
438  Value posLo;
439  ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
440  std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
441 
442  seek({posLo, genSegmentHigh(b, l, posLo)});
443  }
444 
445  SmallVector<Value> serialize() const override {
446  SmallVector<Value> ret;
447  ret.append(getCursor().begin(), getCursor().end());
448  ret.push_back(posHi);
449  return ret;
450  };
451  void deserialize(ValueRange vs) override {
452  assert(vs.size() == 3);
453  seek(vs.take_front(getCursor().size()));
454  posHi = vs.back();
455  };
456 
457  Value genNotEndImpl(OpBuilder &b, Location l) override {
458  return CMPI(ult, getPos(), posHi);
459  }
460 
461  Value derefImpl(OpBuilder &b, Location l) override {
462  updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getPos()));
463  return getCrd();
464  };
465 
466  ValueRange forwardImpl(OpBuilder &b, Location l) override {
467  Value nxPos = getSegHi(); // forward the position to the next segment.
468  seek({nxPos, genSegmentHigh(b, l, nxPos)});
469  return getCursor();
470  }
471 
472  Value getPos() const { return getCursor()[0]; }
473  Value getSegHi() const { return getCursor()[1]; }
474 
475  Value posHi;
476 };
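// Note: the dedup cursor is the pair {pos, segHi}; derefImpl reads the
// coordinate at pos, and forwardImpl jumps pos to segHi, skipping all
// duplicate coordinates of the current segment in one step.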
477 
478 // A utility base iterator that delegates all methods to the wrapped iterator.
479 class SimpleWrapIterator : public SparseIterator {
480 public:
481  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
482  unsigned extraCursorVal = 0)
483  : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}
484 
485  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
486  return wrap->getCursorValTypes(b);
487  }
488  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
489  bool randomAccessible() const override { return wrap->randomAccessible(); };
490  bool iteratableByFor() const override { return wrap->iteratableByFor(); };
491 
492  SmallVector<Value> serialize() const override { return wrap->serialize(); };
493  void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
494  ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
495  void genInitImpl(OpBuilder &b, Location l,
496  const SparseIterator *parent) override {
497  wrap->genInit(b, l, parent);
498  }
499  Value genNotEndImpl(OpBuilder &b, Location l) override {
500  return wrap->genNotEndImpl(b, l);
501  }
502  ValueRange forwardImpl(OpBuilder &b, Location l) override {
503  return wrap->forward(b, l);
504  };
505  Value upperBound(OpBuilder &b, Location l) const override {
506  return wrap->upperBound(b, l);
507  };
508 
509  Value derefImpl(OpBuilder &b, Location l) override {
510  return wrap->derefImpl(b, l);
511  }
512 
513  void locateImpl(OpBuilder &b, Location l, Value crd) override {
514  return wrap->locate(b, l, crd);
515  }
516 
517  SparseIterator &getWrappedIterator() const { return *wrap; }
518 
519 protected:
520  std::unique_ptr<SparseIterator> wrap;
521 };
522 
523 //
524 // A filter iterator wrapped around another iterator. The filter iterator
525 // updates the wrapped iterator *in-place*.
526 //
527 class FilterIterator : public SimpleWrapIterator {
528  // Coordinate translation between the crd loaded from the wrapped iterator
529  // and the crd seen by the filter iterator.
530  Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) const {
531  // crd = (wrapCrd - offset) / stride
532  return DIVUI(SUBI(wrapCrd, offset), stride);
533  }
534  Value toWrapCrd(OpBuilder &b, Location l, Value crd) const {
535  // wrapCrd = crd * stride + offset
536  return ADDI(MULI(crd, stride), offset);
537  }
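  // Example (illustrative): with offset = 2 and stride = 3, wrapCrd = 8 maps
  // to crd = (8 - 2) / 3 = 2, and toWrapCrd(2) = 2 * 3 + 2 = 8 round-trips.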
538 
539  Value genCrdNotLegitPredicate(OpBuilder &b, Location l, Value wrapCrd);
540 
541  Value genShouldFilter(OpBuilder &b, Location l);
542 
543 public:
544  // TODO: avoid unnecessary checks when offset == 0 and/or when stride == 1
545  // and/or when crd is always < size.
546  FilterIterator(std::unique_ptr<SparseIterator> &&wrap, Value offset,
547  Value stride, Value size)
548  : SimpleWrapIterator(std::move(wrap), IterKind::kFilter), offset(offset),
549  stride(stride), size(size) {}
550 
551  // For LLVM-style RTTI.
552  static bool classof(const SparseIterator *from) {
553  return from->kind == IterKind::kFilter;
554  }
555 
556  std::string getDebugInterfacePrefix() const override {
557  return std::string("filter<") + wrap->getDebugInterfacePrefix() + ">";
558  }
559 
560  bool iteratableByFor() const override { return randomAccessible(); };
561  Value upperBound(OpBuilder &b, Location l) const override { return size; };
562 
563  void genInitImpl(OpBuilder &b, Location l,
564  const SparseIterator *parent) override {
565  wrap->genInit(b, l, parent);
566  if (!randomAccessible()) {
567  // TODO: we can skip this when stride == 1 and offset == 0, we can also
568  // use binary search here.
569  forwardIf(b, l, genShouldFilter(b, l));
570  } else {
571  // Else, locate to the slice.offset, which is the first coordinate
572  // included by the slice.
573  wrap->locate(b, l, offset);
574  }
575  }
576 
577  Value genNotEndImpl(OpBuilder &b, Location l) override;
578 
579  Value derefImpl(OpBuilder &b, Location l) override {
580  updateCrd(fromWrapCrd(b, l, wrap->deref(b, l)));
581  return getCrd();
582  }
583 
584  void locateImpl(OpBuilder &b, Location l, Value crd) override {
585  assert(randomAccessible());
586  wrap->locate(b, l, toWrapCrd(b, l, crd));
587  updateCrd(crd);
588  }
589 
590  ValueRange forwardImpl(OpBuilder &b, Location l) override;
591 
592  Value offset, stride, size;
593 };
594 
595 //
596 // A pad iterator wrapped around another iterator. The pad iterator updates
597 // the wrapped iterator *in-place*.
598 //
599 class PadIterator : public SimpleWrapIterator {
600 
601 public:
602  PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
603  Value padHigh)
604  : SimpleWrapIterator(std::move(wrap), IterKind::kPad,
605  wrap->randomAccessible() ? 1 : 0),
606  padLow(padLow), padHigh(padHigh) {}
607 
608  // For LLVM-style RTTI.
609  static bool classof(const SparseIterator *from) {
610  return from->kind == IterKind::kPad;
611  }
612 
613  std::string getDebugInterfacePrefix() const override {
614  return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
615  }
616 
617  // Returns a pair of values for the *lower* and *upper* bound, respectively.
618  ValuePair genForCond(OpBuilder &b, Location l) override {
619  if (randomAccessible())
620  return {getCrd(), upperBound(b, l)};
621  return wrap->genForCond(b, l);
622  }
623 
624  // For a padded dense iterator, we append an `inPadZone: bool` in addition
625  // to the values used by the wrapped iterator.
626  ValueRange getCurPosition() const override { return getCursor(); }
627 
628  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
629  SmallVector<Type> ret = wrap->getCursorValTypes(b);
630  // Need an extra boolean value `inPadZone` for padded dense iterator.
631  if (randomAccessible())
632  ret.push_back(b.getI1Type());
633 
634  return ret;
635  }
636 
637  // The upper bound after padding becomes `size + padLow + padHigh`.
638  Value upperBound(OpBuilder &b, Location l) const override {
639  return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
640  };
641 
642  // The pad_coord = coord + pad_lo
643  Value derefImpl(OpBuilder &b, Location l) override {
644  updateCrd(ADDI(wrap->deref(b, l), padLow));
645  return getCrd();
646  }
647 
648  void locateImpl(OpBuilder &b, Location l, Value crd) override {
649  assert(randomAccessible());
650  wrap->locate(b, l, SUBI(crd, padLow));
651 
652  // inPadZone = crd < padLow || crd >= size + padLow.
653  Value inPadLow = CMPI(ult, crd, padLow);
654  Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow));
655  getMutCursorVals().back() = ORI(inPadLow, inPadHigh);
656 
657  updateCrd(crd);
658  }
659 
660  Value padLow, padHigh;
661 };
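// Example (illustrative): padding a level of size 4 with padLow = 2 and
// padHigh = 1 gives an upper bound of 7; coordinates 0-1 and 6 fall in the
// pad zone, so locateImpl sets the trailing `inPadZone` cursor value to true
// for them.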
662 
663 class NonEmptySubSectIterator : public SparseIterator {
664 public:
665  using TraverseBuilder = llvm::function_ref<scf::ValueVector(
666       OpBuilder &, Location, const SparseIterator *, ValueRange)>;
667 
668  NonEmptySubSectIterator(OpBuilder &b, Location l,
669  const SparseIterator *parent,
670  std::unique_ptr<SparseIterator> &&delegate,
671  Value subSectSz)
672  : SparseIterator(IterKind::kNonEmptySubSect, 3, subSectMeta, *delegate),
673  parent(parent), delegate(std::move(delegate)),
674  tupleSz(this->delegate->serialize().size()), subSectSz(subSectSz) {
675  auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
676  if (p == nullptr) {
677  // Extract subsections along the root level.
678  maxTupleCnt = C_IDX(1);
679  } else if (p->lvl == lvl) {
680  // Extract subsections along the same level.
681  maxTupleCnt = p->maxTupleCnt;
682  assert(false && "Not implemented.");
683  } else {
684  // Extract subsections along the previous level.
685  assert(p->lvl + 1 == lvl);
686  maxTupleCnt = MULI(p->maxTupleCnt, p->subSectSz);
687  }
688  // We don't need an extra buffer to find subsections on random-accessible
689  // levels.
690  if (randomAccessible())
691  return;
692  subSectPosBuf = allocSubSectPosBuf(b, l);
693  }
694 
695  // For LLVM-style RTTI.
696  static bool classof(const SparseIterator *from) {
697  return from->kind == IterKind::kNonEmptySubSect;
698  }
699 
700  std::string getDebugInterfacePrefix() const override {
701  return std::string("ne_sub<") + delegate->getDebugInterfacePrefix() + ">";
702  }
703  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
704  // minCrd, absolute offset, notEnd
705  return {b.getIndexType(), b.getIndexType(), b.getI1Type()};
706  }
707 
708  // The sliced pointer buffer is organized as:
709  // [[itVal0, itVal1, ..., pNx0],
710  // [itVal0, itVal1, ..., pNx0],
711  // ...]
712  Value allocSubSectPosBuf(OpBuilder &b, Location l) {
713  return b.create<memref::AllocaOp>(
714  l,
715  MemRefType::get({ShapedType::kDynamic, tupleSz + 1}, b.getIndexType()),
716  maxTupleCnt);
717  }
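  // The buffer is a memref<? x (tupleSz + 1) x index> allocated with
  // maxTupleCnt rows: tupleSz serialized iterator values plus one slot for the
  // next-level start (pNx0) per row, matching the layout sketched above.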
718 
719  void storeNxLvlStart(OpBuilder &b, Location l, Value tupleId,
720  Value start) const {
721  b.create<memref::StoreOp>(l, start, subSectPosBuf,
722  ValueRange{tupleId, C_IDX(tupleSz)});
723  }
724 
725  Value loadNxLvlStart(OpBuilder &b, Location l, Value tupleId) const {
726  return b.create<memref::LoadOp>(l, subSectPosBuf,
727  ValueRange{tupleId, C_IDX(tupleSz)});
728  }
729 
730  void storeCursorVals(OpBuilder &b, Location l, Value tupleId,
731  ValueRange itVals) const {
732  assert(itVals.size() == tupleSz);
733  for (unsigned i = 0; i < tupleSz; i++) {
734  b.create<memref::StoreOp>(l, itVals[i], subSectPosBuf,
735  ValueRange{tupleId, C_IDX(i)});
736  }
737  }
738 
739  SmallVector<Value> loadCursorVals(OpBuilder &b, Location l,
740  Value tupleId) const {
741  SmallVector<Value> ret;
742  for (unsigned i = 0; i < tupleSz; i++) {
743  Value v = b.create<memref::LoadOp>(l, subSectPosBuf,
744  ValueRange{tupleId, C_IDX(i)});
745  ret.push_back(v);
746  }
747  return ret;
748  }
749 
750  bool isSubSectRoot() const {
751  return !parent || !llvm::isa<NonEmptySubSectIterator>(parent);
752  }
753 
754  // Generates code that inflates the current subsection tree up to the current
755  // level such that every leaf node is visited.
756  ValueRange inflateSubSectTree(OpBuilder &b, Location l, ValueRange reduc,
757  TraverseBuilder builder) const;
758 
759  bool isBatchIterator() const override { return delegate->isBatchIterator(); }
760  bool randomAccessible() const override {
761  return delegate->randomAccessible();
762  };
763  bool iteratableByFor() const override { return randomAccessible(); };
764  Value upperBound(OpBuilder &b, Location l) const override {
765  auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
766  Value parentUB =
767  p && p->lvl == lvl ? p->upperBound(b, l) : delegate->upperBound(b, l);
768  return ADDI(SUBI(parentUB, subSectSz), C_IDX(1));
769  };
770 
771  void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override;
772 
773  void locateImpl(OpBuilder &b, Location l, Value crd) override {
774  Value absOff = crd;
775 
776  if (isSubSectRoot())
777  delegate->locate(b, l, absOff);
778  else
779  assert(parent->lvl + 1 == lvl);
780 
781  seek(ValueRange{absOff, absOff, C_TRUE});
782  updateCrd(crd);
783  }
784 
785  Value toSubSectCrd(OpBuilder &b, Location l, Value wrapCrd) const {
786  return SUBI(wrapCrd, getAbsOff());
787  }
788 
789  Value genNotEndImpl(OpBuilder &b, Location l) override {
790  return getNotEnd();
791  };
792 
793  Value derefImpl(OpBuilder &b, Location l) override {
794  // Use the relative offset to coiterate.
795  Value crd;
796  auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
797  if (p && p->lvl == lvl)
798  crd = SUBI(getAbsOff(), p->getAbsOff());
799  crd = getAbsOff();
800 
801  updateCrd(crd);
802  return crd;
803  };
804 
805  ValueRange forwardImpl(OpBuilder &b, Location l) override;
806 
807  Value getMinCrd() const { return subSectMeta[0]; }
808  Value getAbsOff() const { return subSectMeta[1]; }
809  Value getNotEnd() const { return subSectMeta[2]; }
810 
811  const SparseIterator *parent;
812  std::unique_ptr<SparseIterator> delegate;
813 
814  // Number of values required to serialize the wrapped iterator.
815  const unsigned tupleSz;
816  // Maximum number of tuples, and the actual number of tuples.
817  Value maxTupleCnt, tupleCnt;
818  // The memory used to cache the tuple serialized from the wrapped iterator.
819  Value subSectPosBuf;
820 
821  const Value subSectSz;
822 
823  // minCrd, absolute offset, notEnd
824  SmallVector<Value, 3> subSectMeta{nullptr, nullptr, nullptr};
825 };
826 
827 class SubSectIterator;
828 
829 // A wrapper that helps generate code to traverse a subsection, used
830 // by both `NonEmptySubSectIterator` and `SubSectIterator`.
831 struct SubSectIterHelper {
832  explicit SubSectIterHelper(const SubSectIterator &iter);
833  explicit SubSectIterHelper(const NonEmptySubSectIterator &subSect);
834 
835  // Delegate methods.
836  void deserializeFromTupleId(OpBuilder &b, Location l, Value tupleId);
837  void locate(OpBuilder &b, Location l, Value crd);
838  Value genNotEnd(OpBuilder &b, Location l);
839  Value deref(OpBuilder &b, Location l);
840  ValueRange forward(OpBuilder &b, Location l);
841 
842  const NonEmptySubSectIterator &subSect;
843  SparseIterator &wrap;
844 };
845 
846 class SubSectIterator : public SparseIterator {
847 public:
848  SubSectIterator(const NonEmptySubSectIterator &subSect,
849  const SparseIterator &parent,
850  std::unique_ptr<SparseIterator> &&wrap)
851  : SparseIterator(IterKind::kSubSect, *wrap,
852  /*extraCursorCnt=*/wrap->randomAccessible() ? 0 : 1),
853  subSect(subSect), wrap(std::move(wrap)), parent(parent), helper(*this) {
854  assert(subSect.tid == tid && subSect.lvl == lvl);
855  assert(parent.kind != IterKind::kSubSect || parent.lvl + 1 == lvl);
856  };
857 
858  // For LLVM-style RTTI.
859  static bool classof(const SparseIterator *from) {
860  return from->kind == IterKind::kSubSect;
861  }
862 
863  std::string getDebugInterfacePrefix() const override {
864  return std::string("subsect<") + wrap->getDebugInterfacePrefix() + ">";
865  }
866  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
867  SmallVector<Type> ret = wrap->getCursorValTypes(b);
868  if (!randomAccessible())
869  ret.push_back(b.getIndexType()); // The extra counter.
870  return ret;
871  }
872 
873  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
874  bool randomAccessible() const override { return wrap->randomAccessible(); };
875  bool iteratableByFor() const override { return randomAccessible(); };
876  Value upperBound(OpBuilder &b, Location l) const override {
877  return subSect.subSectSz;
878  }
879 
880  ValueRange getCurPosition() const override { return wrap->getCurPosition(); };
881 
882  Value getNxLvlTupleId(OpBuilder &b, Location l) const {
883  if (randomAccessible()) {
884  return ADDI(getCrd(), nxLvlTupleStart);
885  };
886  return ADDI(getCursor().back(), nxLvlTupleStart);
887  }
888 
889  void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override {
890  if (randomAccessible()) {
891  if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
892  assert(p->lvl + 1 == lvl);
893  wrap->genInit(b, l, p);
894  // Linearize the dense subsection index.
895  nxLvlTupleStart = MULI(subSect.subSectSz, p->getNxLvlTupleId(b, l));
896  } else {
897  assert(subSect.lvl == lvl && subSect.isSubSectRoot());
898  wrap->deserialize(subSect.delegate->serialize());
899  nxLvlTupleStart = C_IDX(0);
900  }
901  return;
902  }
903  assert(!randomAccessible());
904  assert(getCursor().size() == wrap->getCursor().size() + 1);
905  // Extra counter that counts the number of actually visited coordinates in
906  // the sparse subsection.
907  getMutCursorVals().back() = C_IDX(0);
908  Value tupleId;
909  if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
910  assert(p->lvl + 1 == lvl);
911  tupleId = p->getNxLvlTupleId(b, l);
912  } else {
913  assert(subSect.lvl == lvl && subSect.isSubSectRoot());
914  tupleId = C_IDX(0);
915  }
916  nxLvlTupleStart = subSect.loadNxLvlStart(b, l, tupleId);
917  helper.deserializeFromTupleId(b, l, tupleId);
918  }
919 
920  void locateImpl(OpBuilder &b, Location l, Value crd) override {
921  helper.locate(b, l, crd);
922  updateCrd(crd);
923  }
924 
925  Value genNotEndImpl(OpBuilder &b, Location l) override {
926  return helper.genNotEnd(b, l);
927  }
928 
929  Value derefImpl(OpBuilder &b, Location l) override {
930  Value crd = helper.deref(b, l);
931  updateCrd(crd);
932  return crd;
933  };
934 
935  ValueRange forwardImpl(OpBuilder &b, Location l) override {
936  helper.forward(b, l);
937  assert(!randomAccessible());
938  assert(getCursor().size() == wrap->getCursor().size() + 1);
939  getMutCursorVals().back() = ADDI(getCursor().back(), C_IDX(1));
940  return getCursor();
941  };
942 
943  Value nxLvlTupleStart;
944 
945  const NonEmptySubSectIterator &subSect;
946  std::unique_ptr<SparseIterator> wrap;
947  const SparseIterator &parent;
948 
949  SubSectIterHelper helper;
950 };
951 
952 } // namespace
953 
954 //===----------------------------------------------------------------------===//
955 // SparseIterator derived classes implementation.
956 //===----------------------------------------------------------------------===//
957 
958 void SparseIterator::genInit(OpBuilder &b, Location l,
959                              const SparseIterator *p) {
960  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
961  std::string prefix = getDebugInterfacePrefix();
962  Operation *begin = b.create(l, b.getStringAttr(prefix + ".begin"), {},
963  getCursorValTypes(b));
964  seek(begin->getResults());
965  return;
966  }
967  // Inherit batch coordinates from the parent.
968  if (p)
969  inherentBatch(*p);
970  // TODO: support lowering to function call.
971  return genInitImpl(b, l, p);
972 }
973 
974 Value SparseIterator::genNotEnd(OpBuilder &b, Location l) {
975  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
976  std::string prefix = getDebugInterfacePrefix();
977  Operation *notEnd = b.create(l, b.getStringAttr(prefix + ".not_end"),
978  getCursor(), b.getI1Type());
979  return notEnd->getResult(0);
980  }
981  // TODO: support lowering to function call.
982  return genNotEndImpl(b, l);
983 }
984 
985 void SparseIterator::locate(OpBuilder &b, Location l, Value crd) {
986  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
987  std::string prefix = getDebugInterfacePrefix();
988  SmallVector<Value> args = getCursor();
989  args.push_back(crd);
990  Operation *locate = b.create(l, b.getStringAttr(prefix + ".locate"), args,
991  getCursorValTypes(b));
992  seek(locate->getResults());
993  updateCrd(crd);
994  return;
995  }
996  return locateImpl(b, l, crd);
997 }
998 
999 Value SparseIterator::deref(OpBuilder &b, Location l) {
1000  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
1001  std::string prefix = getDebugInterfacePrefix();
1002  SmallVector<Value> args = getCursor();
1003  Operation *deref = b.create(l, b.getStringAttr(prefix + ".deref"),
1004  getCursor(), b.getIndexType());
1005  updateCrd(deref->getResult(0));
1006  return getCrd();
1007  }
1008  return derefImpl(b, l);
1009 }
1010 
1011 ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
1012  assert(!randomAccessible());
1013  if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
1014  std::string prefix = getDebugInterfacePrefix();
1015  Operation *next = b.create(l, b.getStringAttr(prefix + ".next"),
1016  getCursor(), getCursorValTypes(b));
1017  seek(next->getResults());
1018  return getCursor();
1019  }
1020  return forwardImpl(b, l);
1021 }
1022 
1023 ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) {
1024  auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), cond, true);
1025  // Generate else branch first, otherwise iterator values will be updated by
1026  // `forward()`.
1027  b.setInsertionPointToStart(ifOp.elseBlock());
1028  YIELD(getCursor());
1029 
1030  b.setInsertionPointToStart(ifOp.thenBlock());
1031  YIELD(forward(b, l));
1032 
1033  b.setInsertionPointAfter(ifOp);
1034  seek(ifOp.getResults());
1035  return getCursor();
1036 }
1037 
1038 Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
1039  auto whileOp = b.create<scf::WhileOp>(
1040  l, pos.getType(), pos,
1041  /*beforeBuilder=*/
1042  [this, pos](OpBuilder &b, Location l, ValueRange ivs) {
1043  Value inBound = CMPI(ult, ivs.front(), posHi);
1044  auto ifInBound = b.create<scf::IfOp>(l, b.getI1Type(), inBound, true);
1045  {
1046  OpBuilder::InsertionGuard guard(b);
1047  // If in bound, load the next coordinates and check duplication.
1048  b.setInsertionPointToStart(ifInBound.thenBlock());
1049  Value headCrd = stl.peekCrdAt(b, l, getBatchCrds(), pos);
1050  Value tailCrd = stl.peekCrdAt(b, l, getBatchCrds(), ivs.front());
1051  Value isDup = CMPI(eq, headCrd, tailCrd);
1052  YIELD(isDup);
1053  // Else, the position is out of bound, yield false.
1054  b.setInsertionPointToStart(ifInBound.elseBlock());
1055  YIELD(constantI1(b, l, false));
1056  }
1057  b.create<scf::ConditionOp>(l, ifInBound.getResults()[0], ivs);
1058  },
1059  /*afterBuilder=*/
1060  [](OpBuilder &b, Location l, ValueRange ivs) {
1061  Value nxPos = ADDI(ivs[0], C_IDX(1));
1062  YIELD(nxPos);
1063  });
1064  // Return the segment high.
1065  return whileOp.getResult(0);
1066 }
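// Example (illustrative): for coordinates crd = [1, 1, 1, 2] with pos pointing
// at the first 1, the while loop above advances as long as the coordinate
// stays equal, so genSegmentHigh returns the position of the 2 (pos + 3).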
1067 
1068 Value FilterIterator::genCrdNotLegitPredicate(OpBuilder &b, Location l,
1069  Value wrapCrd) {
1070  Value crd = fromWrapCrd(b, l, wrapCrd);
1071  // Test whether the coordinate is on stride.
1072  Value notlegit = CMPI(ne, toWrapCrd(b, l, crd), wrapCrd);
1073  // Test wrapCrd < offset
1074  notlegit = ORI(CMPI(ult, wrapCrd, offset), notlegit);
1075  // Test crd >= length
1076  notlegit = ORI(CMPI(uge, crd, size), notlegit);
1077  return notlegit;
1078 }
1079 
1080 Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
1081  auto r = genWhenInBound(
1082  b, l, *wrap, C_FALSE,
1083  [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1084  Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
1085  return {notLegit};
1086  });
1087 
1088  assert(r.size() == 1);
1089  return r.front();
1090 }
1091 
1092 Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) {
1093  assert(!wrap->randomAccessible());
1094  auto r = genWhenInBound(
1095  b, l, *wrap, C_FALSE,
1096  [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1097  Value crd = fromWrapCrd(b, l, wrapCrd);
1098  // crd < size
1099  return {CMPI(ult, crd, size)};
1100  });
1101  assert(r.size() == 1);
1102  return r.front();
1103 }
1104 
1105 ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
1106  assert(!randomAccessible());
1107  // Generates
1108  //
1109  // bool isFirst = true;
1110  // while !it.end() && (!legit(*it) || isFirst)
1111  // wrap ++;
1112  // isFirst = false;
1113  //
1114  // We do not hoist the first `wrap++` outside the loop but use an `isFirst`
1115  // flag here because `wrap++` might have a complex implementation (e.g., to
1116  // forward a subsection).
1117  Value isFirst = constantI1(b, l, true);
1118 
1119  SmallVector<Value> whileArgs(getCursor().begin(), getCursor().end());
1120  whileArgs.push_back(isFirst);
1121  auto whileOp = b.create<scf::WhileOp>(
1122  l, ValueRange(whileArgs).getTypes(), whileArgs,
1123  /*beforeBuilder=*/
1124  [this](OpBuilder &b, Location l, ValueRange ivs) {
1125  ValueRange isFirst = linkNewScope(ivs);
1126  assert(isFirst.size() == 1);
1127  scf::ValueVector cont =
1128  genWhenInBound(b, l, *wrap, C_FALSE,
1129  [this, isFirst](OpBuilder &b, Location l,
1130  Value wrapCrd) -> scf::ValueVector {
1131  // crd < size && !legit();
1132  Value notLegit =
1133  genCrdNotLegitPredicate(b, l, wrapCrd);
1134  Value crd = fromWrapCrd(b, l, wrapCrd);
1135  Value ret = ANDI(CMPI(ult, crd, size), notLegit);
1136  ret = ORI(ret, isFirst.front());
1137  return {ret};
1138  });
1139  b.create<scf::ConditionOp>(l, cont.front(), ivs);
1140  },
1141  /*afterBuilder=*/
1142  [this](OpBuilder &b, Location l, ValueRange ivs) {
1143  linkNewScope(ivs);
1144  wrap->forward(b, l);
1145  SmallVector<Value> yieldVals(getCursor().begin(), getCursor().end());
1146  yieldVals.push_back(constantI1(b, l, false));
1147  YIELD(yieldVals);
1148  });
1149 
1150  b.setInsertionPointAfter(whileOp);
1151  linkNewScope(whileOp.getResults());
1152  return getCursor();
1153 }
1154 
1155 SubSectIterHelper::SubSectIterHelper(const NonEmptySubSectIterator &subSect)
1156  : subSect(subSect), wrap(*subSect.delegate) {}
1157 
1158 SubSectIterHelper::SubSectIterHelper(const SubSectIterator &iter)
1159  : subSect(iter.subSect), wrap(*iter.wrap) {}
1160 
1161 void SubSectIterHelper::deserializeFromTupleId(OpBuilder &b, Location l,
1162  Value tupleId) {
1163  assert(!subSect.randomAccessible());
1164  wrap.deserialize(subSect.loadCursorVals(b, l, tupleId));
1165 }
1166 
1167 void SubSectIterHelper::locate(OpBuilder &b, Location l, Value crd) {
1168  Value absCrd = ADDI(crd, subSect.getAbsOff());
1169  wrap.locate(b, l, absCrd);
1170 }
1171 
1172 Value SubSectIterHelper::genNotEnd(OpBuilder &b, Location l) {
1173  assert(!wrap.randomAccessible());
1174  auto r = genWhenInBound(
1175  b, l, wrap, C_FALSE,
1176  [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1177  Value crd = SUBI(wrapCrd, subSect.getAbsOff());
1178  // crd < size
1179  return {CMPI(ult, crd, subSect.subSectSz)};
1180  });
1181  assert(r.size() == 1);
1182  return r.front();
1183 }
1184 
1185 Value SubSectIterHelper::deref(OpBuilder &b, Location l) {
1186  Value wrapCrd = wrap.deref(b, l);
1187  Value crd = subSect.toSubSectCrd(b, l, wrapCrd);
1188  return crd;
1189 }
1190 
1191 ValueRange SubSectIterHelper::forward(OpBuilder &b, Location l) {
1192  return wrap.forward(b, l);
1193 }
1194 
1195 ValueRange NonEmptySubSectIterator::inflateSubSectTree(
1196  OpBuilder &b, Location l, ValueRange reduc, TraverseBuilder builder) const {
1197  // Set up the helper to help traverse a sparse subsection.
1198  SubSectIterHelper helper(*this);
1199  if (!randomAccessible()) {
1200  // The subsection tree has been expanded up to this level and cached;
1201  // traverse all the leaves and expand them to the next level.
1202  SmallVector<Value> iterArgs;
1203  iterArgs.push_back(C_IDX(0));
1204  iterArgs.append(reduc.begin(), reduc.end());
1205  auto forEachLeaf = b.create<scf::ForOp>(
1206  l, /*lb=*/C_IDX(0), /*ub=*/tupleCnt, /*step=*/C_IDX(1), iterArgs,
1207  [&helper, &builder](OpBuilder &b, Location l, Value tupleId,
1208  ValueRange iterArgs) {
1209  // Deserialize the iterator at the cached position (tupleId).
1210  helper.deserializeFromTupleId(b, l, tupleId);
1211 
1212  Value cnt = iterArgs.front();
1213  // Record the number of leaf nodes included in the subsection.
1214  // The number indicates the starting tupleId for the next level that
1215  // corresponds to the current node.
1216  helper.subSect.storeNxLvlStart(b, l, tupleId, cnt);
1217 
1218  SmallVector<Value> whileArgs(helper.wrap.getCursor());
1219  whileArgs.append(iterArgs.begin(), iterArgs.end());
1220 
1221  auto whileOp = b.create<scf::WhileOp>(
1222  l, ValueRange(whileArgs).getTypes(), whileArgs,
1223  /*beforeBuilder=*/
1224  [&helper](OpBuilder &b, Location l, ValueRange ivs) {
1225  helper.wrap.linkNewScope(ivs);
1226  b.create<scf::ConditionOp>(l, helper.genNotEnd(b, l), ivs);
1227  },
1228  /*afterBuilder=*/
1229  [&helper, &builder](OpBuilder &b, Location l, ValueRange ivs) {
1230  ValueRange remIter = helper.wrap.linkNewScope(ivs);
1231  Value cnt = remIter.front();
1232  ValueRange userIter = remIter.drop_front();
1233  scf::ValueVector userNx = builder(b, l, &helper.wrap, userIter);
1234 
1235  SmallVector<Value> nxIter = helper.forward(b, l);
1236  nxIter.push_back(ADDI(cnt, C_IDX(1)));
1237  nxIter.append(userNx.begin(), userNx.end());
1238  YIELD(nxIter);
1239  });
1240  ValueRange res = helper.wrap.linkNewScope(whileOp.getResults());
1241  YIELD(res);
1242  });
1243  return forEachLeaf.getResults().drop_front();
1244  }
1245 
1246  assert(randomAccessible());
1247  // Helper lambda that traverses the current dense subsection range.
1248  auto visitDenseSubSect = [&, this](OpBuilder &b, Location l,
1249  const SparseIterator *parent,
1250  ValueRange reduc) {
1251  assert(!parent || parent->lvl + 1 == lvl);
1252  delegate->genInit(b, l, parent);
1253  auto forOp = b.create<scf::ForOp>(
1254  l, /*lb=*/C_IDX(0), /*ub=*/subSectSz, /*step=*/C_IDX(1), reduc,
1255  [&](OpBuilder &b, Location l, Value crd, ValueRange iterArgs) {
1256  helper.locate(b, l, crd);
1257  scf::ValueVector nx = builder(b, l, &helper.wrap, iterArgs);
1258  YIELD(nx);
1259  });
1260  return forOp.getResults();
1261  };
1262 
1263  if (isSubSectRoot()) {
1264  return visitDenseSubSect(b, l, parent, reduc);
1265  }
1266  // Else, this is not the root, recurse until root.
1267  auto *p = llvm::cast<NonEmptySubSectIterator>(parent);
1268  assert(p->lvl + 1 == lvl);
1269  return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
1270 }
1271 
1272 void TrivialIterator::genInitImpl(OpBuilder &b, Location l,
1273  const SparseIterator *parent) {
1274 
1275  if (isBatchIterator() && batchCrds.size() <= stl.lvl)
1276  batchCrds.resize(stl.lvl + 1, nullptr);
1277 
1278  Value c0 = C_IDX(0);
1279  ValueRange pPos = c0;
1280  Value inPadZone = nullptr;
1281  // If the parent iterator is a batch iterator, we also start from 0 (but
1282  // on a different batch).
1283  if (parent && !parent->isBatchIterator()) {
1284  pPos = parent->getCurPosition();
1285  if (llvm::isa<PadIterator>(parent) && parent->randomAccessible()) {
1286  // A padded dense iterator creates a "sparse" padded zone, which needs to
1287  // be handled specially.
1288  inPadZone = pPos.back();
1289  pPos = pPos.drop_back();
1290  }
1291  }
1292 
1293  ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
1294  std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos, inPadZone);
1295  // Seek to the lowest position.
1296  seek(posLo);
1297 }
1298 
1299 void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
1300  const SparseIterator *) {
1301  Value c0 = C_IDX(0);
1302  if (!isSubSectRoot()) {
1303  assert(parent->lvl + 1 == lvl);
1304  if (randomAccessible()) {
1305  // We cannot call wrap->genInit() here to initialize the wrapped
1306  // iterator, because the parent of the current iterator is still
1307  // unresolved.
1308  seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
1309  return;
1310  }
1311 
1312  auto *p = cast<NonEmptySubSectIterator>(parent);
1313  SmallVector<Value, 3> reduc = {
1314  C_IDX(-1), // minCrd (max signless integer)
1315  c0, // tupleId
1316  };
1317 
1318  // Expand the subsection tree from the parent level to the current level.
1319  ValueRange result = p->inflateSubSectTree(
1320  b, l, reduc,
1321  [this](OpBuilder &b, Location l, const SparseIterator *parent,
1322  ValueRange reduc) -> scf::ValueVector {
1323  assert(parent->lvl + 1 == lvl && reduc.size() == 2);
1324  Value minCrd = reduc.front();
1325  Value tupleId = reduc.back();
1326 
1327  // Initialize the subsection range.
1328  SubSectIterHelper helper(*this);
1329  helper.wrap.genInit(b, l, parent);
1330 
1331  // Update minCrd.
1332  minCrd = genWhenInBound(b, l, helper.wrap, minCrd,
1333  [minCrd](OpBuilder &b, Location l,
1334  Value crd) -> scf::ValueVector {
1335  Value min = MINUI(crd, minCrd);
1336  return {min};
1337  })
1338  .front();
1339 
1340  // Cache the sparse range.
1341  storeCursorVals(b, l, tupleId, helper.wrap.serialize());
1342  tupleId = ADDI(tupleId, C_IDX(1));
1343  return {minCrd, tupleId};
1344  });
1345  assert(result.size() == 2);
1346  tupleCnt = result.back();
1347 
1348  Value minCrd = result.front();
1349  Value absOff = offsetFromMinCrd(b, l, minCrd, subSectSz);
1350  Value notEnd = CMPI(ne, minCrd, C_IDX(-1));
1351  seek({minCrd, absOff, notEnd});
1352  return;
1353  }
1354 
1355  // This is the root level of the subsection, which means that it is resolved
1356  // to one node.
1357  assert(isSubSectRoot());
1358 
1359  // Initialize the position; it marks the *lower bound* of the subrange.
1360  // The upper bound is determined by the size of the subsection.
1361  delegate->genInit(b, l, parent);
1362  if (randomAccessible()) {
1363  seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
1364  return;
1365  }
1366 
1367  // Only have one root node.
1368  tupleCnt = C_IDX(1);
1369  // Cache the sparse range.
1370  storeCursorVals(b, l, c0, delegate->serialize());
1371  SmallVector<Value> elseRet{c0, c0, /*notEnd=*/C_FALSE};
1372  auto meta = genWhenInBound(
1373  b, l, *delegate, elseRet,
1374  [this](OpBuilder &b, Location l, Value crd) -> scf::ValueVector {
1375  Value offset = offsetFromMinCrd(b, l, crd, subSectSz);
1376  return {crd, offset, C_TRUE};
1377  });
1378 
1379  seek(meta);
1380 }
1381 
1382 ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
1383  assert(!randomAccessible());
1384  Value c0 = C_IDX(0), c1 = C_IDX(1);
1385  // Forward to the next non-empty slice by generating
1386  //
1387  // if (minCrd > offset) {
1388  // offset += 1
1389  // } else {
1390  // minCrd = nextMinInSlice();
1391  // offset = minCrd - size + 1;
1392  // }
1393  //
1394  // if (offset + size > parents.size)
1395  // isNonEmpty = false;
1396  Value fastPathP = CMPI(ugt, getMinCrd(), getAbsOff());
1397  auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), fastPathP, true);
1398  {
1399  OpBuilder::InsertionGuard guard(b);
1400  // Take the fast path
1401  // if (minCrd > offset)
1402  // offset += 1
1403  b.setInsertionPointToStart(&ifOp.getThenRegion().front());
1404  Value nxOffset = ADDI(getAbsOff(), c1);
1405  YIELD((ValueRange{getMinCrd(), nxOffset, getNotEnd()}));
1406 
1407  // else /*minCrd == offset*/ {
1408  // for (i = 0; i < tupleCnt; i++) {
1409  // wrap->deserialize(pos[i]);
1410  // minCrd=min(minCrd, *wrap);
1411  // }
1412  // offset = minCrd - size + 1;
1413  // }
1414  b.setInsertionPointToStart(&ifOp.getElseRegion().front());
1415  SmallVector<Value, 2> loopArgs{C_IDX(-1), // nextMinCrd
1416  C_FALSE}; // isNotEnd
1417  auto loopNest = scf::buildLoopNest(
1418  b, l, c0, tupleCnt, c1, loopArgs,
1419  [this](OpBuilder &b, Location l, ValueRange ivs,
1420  ValueRange iterArgs) -> scf::ValueVector {
1421  Value tupleId = ivs.front();
1422  SubSectIterHelper helper(*this);
1423  helper.deserializeFromTupleId(b, l, tupleId);
1424 
1425  return genWhenInBound(
1426  b, l, *delegate, /*elseRet=*/iterArgs,
1427  [this, iterArgs, tupleId](OpBuilder &b, Location l,
1428  Value crd) -> scf::ValueVector {
1429  // if coord == minCrd
1430  // wrap->forward();
1431  Value isMin = CMPI(eq, crd, getMinCrd());
1432  delegate->forwardIf(b, l, isMin);
1433  // Update the forwarded iterator values if needed.
1434  auto ifIsMin = b.create<scf::IfOp>(l, isMin, false);
1435  b.setInsertionPointToStart(&ifIsMin.getThenRegion().front());
1436  storeCursorVals(b, l, tupleId, delegate->serialize());
1437  b.setInsertionPointAfter(ifIsMin);
1438  // if (!wrap.end())
1439  // yield(min(nxMinCrd, *wrap), true)
1440  Value nxMin = iterArgs[0];
1441  return genWhenInBound(b, l, *delegate, /*elseRet=*/iterArgs,
1442  [nxMin](OpBuilder &b, Location l,
1443  Value crd) -> scf::ValueVector {
1444  Value nx = b.create<arith::MinUIOp>(
1445  l, crd, nxMin);
1446  return {nx, C_TRUE};
1447  });
1448  });
1449  });
1450 
1451  scf::ForOp forOp = loopNest.loops.front();
1452  b.setInsertionPointAfter(forOp);
1453 
1454  Value nxMinCrd = forOp.getResult(0);
1455  Value nxNotEnd = forOp.getResult(1);
1456  Value nxAbsOff = offsetFromMinCrd(b, l, nxMinCrd, subSectSz);
1457  YIELD((ValueRange{nxMinCrd, nxAbsOff, nxNotEnd}));
1458  }
1459 
1460  Value nxMinCrd = ifOp.getResult(0);
1461  Value nxAbsOff = ifOp.getResult(1);
1462  Value nxNotEnd = ifOp.getResult(2);
1463 
1464  // We should at least forward the offset by one.
1465  Value minAbsOff = ADDI(getAbsOff(), c1);
1466  nxAbsOff = b.create<arith::MaxUIOp>(l, minAbsOff, nxAbsOff);
1467 
1468  seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
1469  // The coordinate should not exceed the space's upper bound.
1470  Value crd = deref(b, l);
1471  nxNotEnd = ANDI(nxNotEnd, CMPI(ult, crd, upperBound(b, l)));
1472 
1473  seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
1474  return getCursor();
1475 }
1476 
1477 //===----------------------------------------------------------------------===//
1478 // SparseIterator factory functions.
1479 //===----------------------------------------------------------------------===//
1480 
1481 std::unique_ptr<SparseTensorLevel>
1482 sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
1483                                      unsigned tid, Level lvl) {
1484  auto stt = getSparseTensorType(t);
1485 
1486  LevelType lt = stt.getLvlType(lvl);
1487  Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
1488  : b.create<tensor::DimOp>(l, t, lvl).getResult();
1489 
1490  switch (lt.getLvlFmt()) {
1491  case LevelFormat::Dense:
1492  return std::make_unique<DenseLevel>(tid, lvl, sz);
1493  case LevelFormat::Batch:
1494  return std::make_unique<BatchLevel>(tid, lvl, sz);
1495  case LevelFormat::Compressed: {
1496  Value pos = b.create<ToPositionsOp>(l, t, lvl);
1497  Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
1498  return std::make_unique<CompressedLevel>(tid, lvl, lt, sz, pos, crd);
1499  }
1500  case LevelFormat::LooseCompressed: {
1501  Value pos = b.create<ToPositionsOp>(l, t, lvl);
1502  Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
1503  return std::make_unique<LooseCompressedLevel>(tid, lvl, lt, sz, pos, crd);
1504  }
1505  case LevelFormat::Singleton: {
1506  Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
1507  return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
1508  }
1509  case LevelFormat::NOutOfM: {
1510  Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
1511  return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
1512  }
1513  case LevelFormat::Undef:
1514  llvm_unreachable("undefined level format");
1515  }
1516  llvm_unreachable("unrecognizable level format");
1517 }
1518 
1519 std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
1520 sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
1521  SparseEmitStrategy strategy) {
1522  auto stl = std::make_unique<BatchLevel>(tid, lvl, sz);
1523  auto it = std::make_unique<TrivialIterator>(*stl);
1524  it->setSparseEmitStrategy(strategy);
1525  return std::make_pair(std::move(stl), std::move(it));
1526 }
1527 
1528 std::unique_ptr<SparseIterator>
1529 sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
1530                                   SparseEmitStrategy strategy) {
1531  std::unique_ptr<SparseIterator> ret;
1532  if (!isUniqueLT(stl.getLT())) {
1533  // We always deduplicate the non-unique level, but we should optimize it away
1534  // if possible.
1535  ret = std::make_unique<DedupIterator>(stl);
1536  } else {
1537  ret = std::make_unique<TrivialIterator>(stl);
1538  }
1539  ret->setSparseEmitStrategy(strategy);
1540  return ret;
1541 }
1542 
1543 std::unique_ptr<SparseIterator>
1544 sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
1545  Value offset, Value stride, Value size,
1546  SparseEmitStrategy strategy) {
1547 
1548  auto ret =
1549  std::make_unique<FilterIterator>(std::move(sit), offset, stride, size);
1550  ret->setSparseEmitStrategy(strategy);
1551  return ret;
1552 }
1553 
1554 std::unique_ptr<SparseIterator>
1555 sparse_tensor::makePaddedIterator(std::unique_ptr<SparseIterator> &&sit,
1556  Value padLow, Value padHigh,
1557  SparseEmitStrategy strategy) {
1558  auto ret = std::make_unique<PadIterator>(std::move(sit), padLow, padHigh);
1559  ret->setSparseEmitStrategy(strategy);
1560  return ret;
1561 }
1562 
1563 static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) {
1564  auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
1565  if (filter)
1566  return &filter->getWrappedIterator();
1567  return it;
1568 }
1569 
1570 std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
1571  OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
1572  std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
1573  SparseEmitStrategy strategy) {
1574 
1575  // Try to unwrap the NonEmptySubSectIterator from a filter parent.
1576  parent = tryUnwrapFilter(parent);
1577  std::unique_ptr<SparseIterator> it =
1578  std::make_unique<NonEmptySubSectIterator>(b, l, parent,
1579  std::move(delegate), size);
1580 
1581  if (stride != 1) {
1582  // TODO: We can safely skip bound checking on sparse levels, but for dense
1583  // iteration space, we need the bound to infer the dense loop range.
1584  it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1585  C_IDX(stride), /*size=*/loopBound);
1586  }
1587  it->setSparseEmitStrategy(strategy);
1588  return it;
1589 }
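// Usage sketch (illustrative): visiting only the non-empty subsections of
// extent `size` along a level, e.g., for the outer loop of a sparse
// convolution. `levelIt`, `ub`, and `subSectSz` are assumed values/iterators
// from the surrounding lowering.
//   std::unique_ptr<SparseIterator> subSects = makeNonEmptySubSectIterator(
//       b, l, /*parent=*/nullptr, /*loopBound=*/ub, std::move(levelIt),
//       /*size=*/subSectSz, /*stride=*/1, SparseEmitStrategy::kFunctional);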
1590 
1591 std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator(
1592  OpBuilder &b, Location l, const SparseIterator &subSectIter,
1593  const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
1594  Value loopBound, unsigned stride, SparseEmitStrategy strategy) {
1595 
1596  // This must be a subsection iterator or a filtered subsection iterator.
1597  auto &subSect =
1598  llvm::cast<NonEmptySubSectIterator>(*tryUnwrapFilter(&subSectIter));
1599 
1600  std::unique_ptr<SparseIterator> it = std::make_unique<SubSectIterator>(
1601  subSect, *tryUnwrapFilter(&parent), std::move(wrap));
1602 
1603  if (stride != 1) {
1604  it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1605  C_IDX(stride), /*size=*/loopBound);
1606  }
1607  it->setSparseEmitStrategy(strategy);
1608  return it;
1609 }
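// Usage sketch (illustrative): traversing the elements inside the subsection
// currently selected by a NonEmptySubSectIterator. `subSects` is assumed to be
// the iterator created above, `parentIt` the iterator of the enclosing loop,
// and `innerIt` a simple iterator over the wrapped level.
//   std::unique_ptr<SparseIterator> inSubSect = makeTraverseSubSectIterator(
//       b, l, /*subSectIter=*/*subSects, /*parent=*/*parentIt,
//       std::move(innerIt), /*loopBound=*/ub, /*stride=*/1,
//       SparseEmitStrategy::kFunctional);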
1610 
1611 #undef CMPI
1612 #undef C_IDX
1613 #undef YIELD
1614 #undef ADDI
1615 #undef ANDI
1616 #undef SUBI
1617 #undef MULI
1618 #undef SELECT