MLIR  20.0.0git
SparseTensorIterator.cpp
Go to the documentation of this file.
1 //===- SparseTensorIterator.cpp -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "SparseTensorIterator.h"
10 #include "CodegenUtils.h"
11 
15 
16 using namespace mlir;
17 using namespace mlir::sparse_tensor;
18 using ValuePair = std::pair<Value, Value>;
19 using ValueTuple = std::tuple<Value, Value, Value>;
20 
21 //===----------------------------------------------------------------------===//
22 // File local helper functions/macros.
23 //===----------------------------------------------------------------------===//
24 #define CMPI(p, lhs, rhs) \
25  (b.create<arith::CmpIOp>(l, arith::CmpIPredicate::p, (lhs), (rhs)) \
26  .getResult())
27 
28 #define C_FALSE (constantI1(b, l, false))
29 #define C_TRUE (constantI1(b, l, true))
30 #define C_IDX(v) (constantIndex(b, l, (v)))
31 #define YIELD(vs) (b.create<scf::YieldOp>(l, (vs)))
32 #define ADDI(lhs, rhs) (b.create<arith::AddIOp>(l, (lhs), (rhs)).getResult())
33 #define ORI(lhs, rhs) (b.create<arith::OrIOp>(l, (lhs), (rhs)).getResult())
34 #define ANDI(lhs, rhs) (b.create<arith::AndIOp>(l, (lhs), (rhs)).getResult())
35 #define SUBI(lhs, rhs) (b.create<arith::SubIOp>(l, (lhs), (rhs)).getResult())
36 #define MULI(lhs, rhs) (b.create<arith::MulIOp>(l, (lhs), (rhs)).getResult())
37 #define MINUI(lhs, rhs) (b.create<arith::MinUIOp>(l, (lhs), (rhs)).getResult())
38 #define REMUI(lhs, rhs) (b.create<arith::RemUIOp>(l, (lhs), (rhs)).getResult())
39 #define DIVUI(lhs, rhs) (b.create<arith::DivUIOp>(l, (lhs), (rhs)).getResult())
40 #define SELECT(c, lhs, rhs) \
41  (b.create<arith::SelectOp>(l, (c), (lhs), (rhs)).getResult())
42 
43 //===----------------------------------------------------------------------===//
44 // SparseTensorLevel derived classes.
45 //===----------------------------------------------------------------------===//
46 
47 namespace {
48 
49 template <bool hasPosBuffer>
50 class SparseLevel : public SparseTensorLevel {
51  // It is either an array of size 2 or size 1 depending on whether the sparse
52  // level requires a position array.
53  using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
54  std::array<Value, 1>>;
55 
56 public:
57  SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
58  BufferT buffers)
59  : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}
60 
61  ValueRange getLvlBuffers() const override { return buffers; }
62 
63  Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
64  Value iv) const override {
65  SmallVector<Value> memCrd(batchPrefix);
66  memCrd.push_back(iv);
67  return genIndexLoad(b, l, getCrdBuf(), memCrd);
68  }
69 
70 protected:
71  template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
72  Value getPosBuf() const {
73  return buffers[0];
74  }
75 
76  Value getCrdBuf() const {
77  if constexpr (hasPosBuffer)
78  return buffers[1];
79  else
80  return buffers[0];
81  }
82 
83  const BufferT buffers;
84 };
85 
86 class DenseLevel : public SparseTensorLevel {
87 public:
88  DenseLevel(unsigned tid, Level lvl, Value lvlSize)
89  : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize) {}
90 
91  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
92  llvm_unreachable("locate random-accessible level instead");
93  }
94 
95  ValueRange getLvlBuffers() const override { return {}; }
96 
97  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
98  ValueRange parentPos, Value inPadZone) const override {
99  assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
100  assert(!inPadZone && "Not implemented");
101  Value p = parentPos.front();
102  Value posLo = MULI(p, lvlSize);
103  return {posLo, lvlSize};
104  }
105 };
106 
107 class BatchLevel : public SparseTensorLevel {
108 public:
109  BatchLevel(unsigned tid, Level lvl, Value lvlSize)
110  : SparseTensorLevel(tid, lvl, LevelFormat::Batch, lvlSize) {}
111 
112  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
113  llvm_unreachable("locate random-accessible level instead");
114  }
115 
116  ValueRange getLvlBuffers() const override { return {}; }
117 
118  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
119  ValueRange parentPos, Value inPadZone) const override {
120  assert(!inPadZone && "Not implemented");
121  assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
122  // No need to linearize the position for non-annotated tensors.
123  return {C_IDX(0), lvlSize};
124  }
125 };
126 
127 class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
128 public:
129  CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
130  Value posBuffer, Value crdBuffer)
131  : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
132 
133  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
134  ValueRange parentPos, Value inPadZone) const override {
135 
136  assert(parentPos.size() == 1 &&
137  "compressed level must be the first non-unique level.");
138 
139  auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
140  Value p = parentPos.front();
141  SmallVector<Value> memCrd(batchPrefix);
142  memCrd.push_back(p);
143  Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
144  memCrd.back() = ADDI(p, C_IDX(1));
145  Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
146  return {pLo, pHi};
147  };
148 
149  if (inPadZone == nullptr)
150  return loadRange();
151 
153  scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
154  // True branch, returns a "fake" empty range [0, 0) if parent
155  // iterator is in pad zone.
156  b.setInsertionPointToStart(posRangeIf.thenBlock());
157 
158  SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
159  b.create<scf::YieldOp>(l, emptyRange);
160 
161  // False branch, returns the actual range.
162  b.setInsertionPointToStart(posRangeIf.elseBlock());
163  auto [pLo, pHi] = loadRange();
164  SmallVector<Value, 2> loadedRange{pLo, pHi};
165  b.create<scf::YieldOp>(l, loadedRange);
166 
167  b.setInsertionPointAfter(posRangeIf);
168  ValueRange posRange = posRangeIf.getResults();
169  return {posRange.front(), posRange.back()};
170  }
171 }; // namespace
172 
173 class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
174 public:
175  LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
176  Value posBuffer, Value crdBuffer)
177  : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
178 
179  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
180  ValueRange parentPos, Value inPadZone) const override {
181  assert(parentPos.size() == 1 &&
182  "loose-compressed level must be the first non-unique level.");
183  assert(!inPadZone && "Not implemented");
184  SmallVector<Value> memCrd(batchPrefix);
185  Value p = parentPos.front();
186  p = MULI(p, C_IDX(2));
187  memCrd.push_back(p);
188  Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
189  memCrd.back() = ADDI(p, C_IDX(1));
190  Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
191  return {pLo, pHi};
192  }
193 }; // namespace
194 
195 class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
196 public:
197  SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
198  Value crdBuffer)
199  : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
200 
201  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
202  ValueRange parentPos, Value inPadZone) const override {
203  assert(parentPos.size() == 1 || parentPos.size() == 2);
204  assert(!inPadZone && "Not implemented");
205  Value p = parentPos.front();
206  Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
207 
208  if (segHi == nullptr)
209  return {p, ADDI(p, C_IDX(1))};
210  // Use the segHi as the loop upper bound.
211  return {p, segHi};
212  }
213 
214  ValuePair
215  collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
216  std::pair<Value, Value> parentRange) const override {
217  // Singleton level keeps the same range after collapsing.
218  return parentRange;
219  };
220 };
221 
222 class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
223 public:
224  NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
225  Value crdBuffer)
226  : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
227 
228  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
229  ValueRange parentPos, Value inPadZone) const override {
230  assert(parentPos.size() == 1 && isUnique() &&
231  "n:m level can not be non-unique.");
232  assert(!inPadZone && "Not implemented");
233  // Each n:m blk has exactly n specified elements.
234  auto n = getN(lt);
235  Value posLo = MULI(parentPos.front(), C_IDX(n));
236  return {posLo, ADDI(posLo, C_IDX(n))};
237  }
238 };
239 
240 } // namespace
241 
242 //===----------------------------------------------------------------------===//
243 // File local helpers
244 //===----------------------------------------------------------------------===//
245 
247  OpBuilder &b, Location l, SparseIterator &it, ValueRange elseRet,
249  builder) {
250  TypeRange ifRetTypes = elseRet.getTypes();
251  auto ifOp = b.create<scf::IfOp>(l, ifRetTypes, it.genNotEnd(b, l), true);
252 
253  b.setInsertionPointToStart(ifOp.thenBlock());
254  Value crd = it.deref(b, l);
255  scf::ValueVector ret = builder(b, l, crd);
256  YIELD(ret);
257 
258  b.setInsertionPointToStart(ifOp.elseBlock());
259  YIELD(elseRet);
260 
261  b.setInsertionPointAfter(ifOp);
262  return ifOp.getResults();
263 }
264 
265 /// Generates code to compute the *absolute* offset of the slice based on the
266 /// provide minimum coordinates in the slice.
267 /// E.g., when reducing d0 + d1 + d2, we need two slices to fully reduced the
268 /// expression, i,e, s1 = slice(T, d0), s2 = slice(s1, d1). The *absolute*
269 /// offset is the offset computed relative to the initial tensors T.
270 ///
271 /// When isNonEmpty == true, the computed offset is meaningless and should not
272 /// be used during runtime, the method generates code to return 0 currently in
273 /// that case.
274 ///
275 /// offset = minCrd >= size ? minCrd - size + 1 : 0;
277  Value size) {
278  Value geSize = CMPI(uge, minCrd, size);
279  // Compute minCrd - size + 1.
280  Value mms = SUBI(ADDI(minCrd, C_IDX(1)), size);
281  // This is the absolute offset related to the actual tensor.
282  return SELECT(geSize, mms, C_IDX(0));
283 }
284 
285 //===----------------------------------------------------------------------===//
286 // SparseIterator derived classes.
287 //===----------------------------------------------------------------------===//
288 
289 namespace {
290 
291 // The iterator that traverses a concrete sparse tensor levels. High-level
292 // abstract iterators wrap it to achieve more complex goals (such as collapsing
293 // several levels). It also holds the common storage to hold the mlir::Values
294 // for itself as well as for wrappers.
295 class ConcreteIterator : public SparseIterator {
296 protected:
297  ConcreteIterator(const SparseTensorLevel &stl, IterKind kind,
298  unsigned cursorValCnt)
299  : SparseIterator(kind, stl.tid, stl.lvl, cursorValCnt, cursorValsStorage),
300  stl(stl), cursorValsStorage(cursorValCnt, nullptr) {
301  assert(getCursor().size() == cursorValCnt);
302  };
303 
304 public:
305  // For LLVM-style RTTI.
306  static bool classof(const SparseIterator *from) {
307  return from->kind == IterKind::kTrivial;
308  }
309 
310  bool isBatchIterator() const override {
311  return stl.getLT().isa<LevelFormat::Batch>();
312  }
313  bool randomAccessible() const override {
314  return stl.getLT().hasDenseSemantic();
315  };
316  bool iteratableByFor() const override { return kind != IterKind::kDedup; };
317  Value upperBound(OpBuilder &b, Location l) const override {
318  return stl.getSize();
319  };
320 
321 protected:
322  const SparseTensorLevel &stl;
323  // Owner of the storage, all wrappers build on top of a concrete iterator
324  // share the same storage such that the iterator values are always
325  // synchronized.
326  SmallVector<Value> cursorValsStorage;
327 };
328 
329 class TrivialIterator : public ConcreteIterator {
330 public:
331  TrivialIterator(const SparseTensorLevel &stl)
332  : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
333 
334  TrivialIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
335  Value posLo, Value posHi)
336  : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1), posLo(posLo),
337  posHi(posHi) {
338  seek(posLo);
339  }
340 
341  std::string getDebugInterfacePrefix() const override {
342  return std::string("trivial<") + stl.toString() + ">";
343  }
344  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
345  return {b.getIndexType()};
346  }
347 
348  SmallVector<Value> serialize() const override {
349  SmallVector<Value> ret;
350  ret.push_back(getItPos());
351  if (randomAccessible()) {
352  // Loop high is implicit (defined by `upperBound()`) for random-access
353  // iterator, but we need to memorize posLo for linearization.
354  ret.push_back(posLo);
355  } else {
356  ret.push_back(posHi);
357  }
358  return ret;
359  };
360 
361  void deserialize(ValueRange vs) override {
362  assert(vs.size() == 2);
363  seek(vs.front());
364  if (randomAccessible())
365  posLo = vs.back();
366  else
367  posHi = vs.back();
368  };
369 
370  void genInitImpl(OpBuilder &b, Location l,
371  const SparseIterator *parent) override;
372 
373  ValuePair genForCond(OpBuilder &b, Location l) override {
374  if (randomAccessible())
375  return {deref(b, l), upperBound(b, l)};
376  return std::make_pair(getItPos(), posHi);
377  }
378 
379  Value genNotEndImpl(OpBuilder &b, Location l) override {
380  // We used the first level bound as the bound the collapsed set of levels.
381  return CMPI(ult, getItPos(), posHi);
382  }
383 
384  Value derefImpl(OpBuilder &b, Location l) override {
385  if (randomAccessible()) {
386  updateCrd(SUBI(getItPos(), posLo));
387  } else {
388  updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getItPos()));
389  }
390  return getCrd();
391  };
392 
393  ValueRange forwardImpl(OpBuilder &b, Location l) override {
394  seek(ADDI(getItPos(), C_IDX(1)));
395  return getCursor();
396  }
397 
398  ValueRange forwardIf(OpBuilder &b, Location l, Value cond) override {
399  Value curPos = getCursor().front();
400  Value nxPos = forward(b, l).front();
401  seek(SELECT(cond, nxPos, curPos));
402  return getCursor();
403  }
404 
405  void locateImpl(OpBuilder &b, Location l, Value crd) override {
406  assert(randomAccessible());
407  // Seek to the linearized position.
408  seek(ADDI(crd, posLo));
409  updateCrd(crd);
410  if (isBatchIterator()) {
411  // If this is a batch iterator, also update the batch coordinate.
412  assert(batchCrds.size() > lvl);
413  batchCrds[lvl] = crd;
414  }
415  }
416 
417  Value getItPos() const { return getCursor().front(); }
418  Value posLo, posHi;
419 };
420 
421 class DedupIterator : public ConcreteIterator {
422 private:
423  Value genSegmentHigh(OpBuilder &b, Location l, Value pos);
424 
425 public:
426  DedupIterator(const SparseTensorLevel &stl)
427  : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) {
428  assert(!stl.isUnique());
429  }
430 
431  DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
432  Value posLo, Value posHi)
433  : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2), posHi(posHi) {
434  assert(!stl.isUnique());
435  seek({posLo, genSegmentHigh(b, l, posLo)});
436  }
437 
438  // For LLVM-style RTTI.
439  static bool classof(const SparseIterator *from) {
440  return from->kind == IterKind::kDedup;
441  }
442 
443  std::string getDebugInterfacePrefix() const override {
444  return std::string("dedup<") + stl.toString() + ">";
445  }
446  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
447  return {b.getIndexType(), b.getIndexType()};
448  }
449 
450  void genInitImpl(OpBuilder &b, Location l,
451  const SparseIterator *parent) override {
452  Value c0 = C_IDX(0);
453  ValueRange pPos = c0;
454 
455  // If the parent iterator is a batch iterator, we also start from 0 (but
456  // on a different batch).
457  if (parent && !parent->isBatchIterator())
458  pPos = parent->getCurPosition();
459 
460  Value posLo;
461  ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
462  std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
463 
464  seek({posLo, genSegmentHigh(b, l, posLo)});
465  }
466 
467  SmallVector<Value> serialize() const override {
468  SmallVector<Value> ret;
469  ret.append(getCursor().begin(), getCursor().end());
470  ret.push_back(posHi);
471  return ret;
472  };
473  void deserialize(ValueRange vs) override {
474  assert(vs.size() == 3);
475  seek(vs.take_front(getCursor().size()));
476  posHi = vs.back();
477  };
478 
479  Value genNotEndImpl(OpBuilder &b, Location l) override {
480  return CMPI(ult, getPos(), posHi);
481  }
482 
483  Value derefImpl(OpBuilder &b, Location l) override {
484  updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getPos()));
485  return getCrd();
486  };
487 
488  ValueRange forwardImpl(OpBuilder &b, Location l) override {
489  Value nxPos = getSegHi(); // forward the position to the next segment.
490  seek({nxPos, genSegmentHigh(b, l, nxPos)});
491  return getCursor();
492  }
493 
494  Value getPos() const { return getCursor()[0]; }
495  Value getSegHi() const { return getCursor()[1]; }
496 
497  Value posHi;
498 };
499 
500 // A util base-iterator that delegates all methods to the wrapped iterator.
501 class SimpleWrapIterator : public SparseIterator {
502 public:
503  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
504  unsigned extraCursorVal = 0)
505  : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}
506 
507  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
508  return wrap->getCursorValTypes(b);
509  }
510  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
511  bool randomAccessible() const override { return wrap->randomAccessible(); };
512  bool iteratableByFor() const override { return wrap->iteratableByFor(); };
513 
514  SmallVector<Value> serialize() const override { return wrap->serialize(); };
515  void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
516  ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
517  void genInitImpl(OpBuilder &b, Location l,
518  const SparseIterator *parent) override {
519  wrap->genInit(b, l, parent);
520  }
521  Value genNotEndImpl(OpBuilder &b, Location l) override {
522  return wrap->genNotEndImpl(b, l);
523  }
524  ValueRange forwardImpl(OpBuilder &b, Location l) override {
525  return wrap->forward(b, l);
526  };
527  Value upperBound(OpBuilder &b, Location l) const override {
528  return wrap->upperBound(b, l);
529  };
530 
531  Value derefImpl(OpBuilder &b, Location l) override {
532  return wrap->derefImpl(b, l);
533  }
534 
535  void locateImpl(OpBuilder &b, Location l, Value crd) override {
536  return wrap->locate(b, l, crd);
537  }
538 
539  SparseIterator &getWrappedIterator() const { return *wrap; }
540 
541 protected:
542  std::unique_ptr<SparseIterator> wrap;
543 };
544 
545 //
546 // A filter iterator wrapped from another iterator. The filter iterator update
547 // the wrapped iterator *in-place*.
548 //
549 class FilterIterator : public SimpleWrapIterator {
550  // Coorindate translation between crd loaded from the wrap iterator and the
551  // filter iterator.
552  Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) const {
553  // crd = (wrapCrd - offset) / stride
554  return DIVUI(SUBI(wrapCrd, offset), stride);
555  }
556  Value toWrapCrd(OpBuilder &b, Location l, Value crd) const {
557  // wrapCrd = crd * stride + offset
558  return ADDI(MULI(crd, stride), offset);
559  }
560 
561  Value genCrdNotLegitPredicate(OpBuilder &b, Location l, Value wrapCrd);
562 
563  Value genShouldFilter(OpBuilder &b, Location l);
564 
565 public:
566  // TODO: avoid unnessary check when offset == 0 and/or when stride == 1 and/or
567  // when crd always < size.
568  FilterIterator(std::unique_ptr<SparseIterator> &&wrap, Value offset,
569  Value stride, Value size)
570  : SimpleWrapIterator(std::move(wrap), IterKind::kFilter), offset(offset),
571  stride(stride), size(size) {}
572 
573  // For LLVM-style RTTI.
574  static bool classof(const SparseIterator *from) {
575  return from->kind == IterKind::kFilter;
576  }
577 
578  std::string getDebugInterfacePrefix() const override {
579  return std::string("filter<") + wrap->getDebugInterfacePrefix() + ">";
580  }
581 
582  bool iteratableByFor() const override { return randomAccessible(); };
583  Value upperBound(OpBuilder &b, Location l) const override { return size; };
584 
585  void genInitImpl(OpBuilder &b, Location l,
586  const SparseIterator *parent) override {
587  wrap->genInit(b, l, parent);
588  if (!randomAccessible()) {
589  // TODO: we can skip this when stride == 1 and offset == 0, we can also
590  // use binary search here.
591  forwardIf(b, l, genShouldFilter(b, l));
592  } else {
593  // Else, locate to the slice.offset, which is the first coordinate
594  // included by the slice.
595  wrap->locate(b, l, offset);
596  }
597  }
598 
599  Value genNotEndImpl(OpBuilder &b, Location l) override;
600 
601  Value derefImpl(OpBuilder &b, Location l) override {
602  updateCrd(fromWrapCrd(b, l, wrap->deref(b, l)));
603  return getCrd();
604  }
605 
606  void locateImpl(OpBuilder &b, Location l, Value crd) override {
607  assert(randomAccessible());
608  wrap->locate(b, l, toWrapCrd(b, l, crd));
609  updateCrd(crd);
610  }
611 
612  ValueRange forwardImpl(OpBuilder &b, Location l) override;
613 
614  Value offset, stride, size;
615 };
616 
617 //
618 // A pad iterator wrapped from another iterator. The pad iterator updates
619 // the wrapped iterator *in-place*.
620 //
621 class PadIterator : public SimpleWrapIterator {
622 
623 public:
624  PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
625  Value padHigh)
626  : SimpleWrapIterator(std::move(wrap), IterKind::kPad,
627  wrap->randomAccessible() ? 1 : 0),
628  padLow(padLow), padHigh(padHigh) {}
629 
630  // For LLVM-style RTTI.
631  static bool classof(const SparseIterator *from) {
632  return from->kind == IterKind::kPad;
633  }
634 
635  std::string getDebugInterfacePrefix() const override {
636  return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
637  }
638 
639  // Returns a pair of values for *upper*, *lower* bound respectively.
640  ValuePair genForCond(OpBuilder &b, Location l) override {
641  if (randomAccessible())
642  return {getCrd(), upperBound(b, l)};
643  return wrap->genForCond(b, l);
644  }
645 
646  // For padded dense iterator, we append a `inPadZone: bool` in addition to
647  // values used by the wrapped iterator.
648  ValueRange getCurPosition() const override { return getCursor(); }
649 
650  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
651  SmallVector<Type> ret = wrap->getCursorValTypes(b);
652  // Need an extra boolean value `inPadZone` for padded dense iterator.
653  if (randomAccessible())
654  ret.push_back(b.getI1Type());
655 
656  return ret;
657  }
658 
659  // The upper bound after padding becomes `size + padLow + padHigh`.
660  Value upperBound(OpBuilder &b, Location l) const override {
661  return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
662  };
663 
664  // The pad_coord = coord + pad_lo
665  Value derefImpl(OpBuilder &b, Location l) override {
666  updateCrd(ADDI(wrap->deref(b, l), padLow));
667  return getCrd();
668  }
669 
670  void locateImpl(OpBuilder &b, Location l, Value crd) override {
671  assert(randomAccessible());
672  wrap->locate(b, l, SUBI(crd, padLow));
673 
674  // inPadZone = crd < padLow || crd >= size + padLow.
675  Value inPadLow = CMPI(ult, crd, padLow);
676  Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow));
677  getMutCursorVals().back() = ORI(inPadLow, inPadHigh);
678 
679  updateCrd(crd);
680  }
681 
682  Value padLow, padHigh;
683 };
684 
685 class NonEmptySubSectIterator : public SparseIterator {
686 public:
687  using TraverseBuilder = llvm::function_ref<scf::ValueVector(
689 
690  NonEmptySubSectIterator(OpBuilder &b, Location l,
691  const SparseIterator *parent,
692  std::unique_ptr<SparseIterator> &&delegate,
693  Value subSectSz)
694  : SparseIterator(IterKind::kNonEmptySubSect, 3, subSectMeta, *delegate),
695  parent(parent), delegate(std::move(delegate)),
696  tupleSz(this->delegate->serialize().size()), subSectSz(subSectSz) {
697  auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
698  if (p == nullptr) {
699  // Extract subsections along the root level.
700  maxTupleCnt = C_IDX(1);
701  } else if (p->lvl == lvl) {
702  // Extract subsections along the same level.
703  maxTupleCnt = p->maxTupleCnt;
704  assert(false && "Not implemented.");
705  } else {
706  // Extract subsections along the previous level.
707  assert(p->lvl + 1 == lvl);
708  maxTupleCnt = MULI(p->maxTupleCnt, p->subSectSz);
709  }
710  // We don't need an extra buffer to find subsections on random-accessible
711  // levels.
712  if (randomAccessible())
713  return;
714  subSectPosBuf = allocSubSectPosBuf(b, l);
715  }
716 
717  // For LLVM-style RTTI.
718  static bool classof(const SparseIterator *from) {
719  return from->kind == IterKind::kNonEmptySubSect;
720  }
721 
722  std::string getDebugInterfacePrefix() const override {
723  return std::string("ne_sub<") + delegate->getDebugInterfacePrefix() + ">";
724  }
725  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
726  // minCrd, absolute offset, notEnd
727  return {b.getIndexType(), b.getIndexType(), b.getI1Type()};
728  }
729 
730  // The sliced pointer buffer is organized as:
731  // [[itVal0, itVal1, ..., pNx0],
732  // [itVal0, itVal1, ..., pNx0],
733  // ...]
734  Value allocSubSectPosBuf(OpBuilder &b, Location l) {
735  return b.create<memref::AllocaOp>(
736  l,
737  MemRefType::get({ShapedType::kDynamic, tupleSz + 1}, b.getIndexType()),
738  maxTupleCnt);
739  }
740 
741  void storeNxLvlStart(OpBuilder &b, Location l, Value tupleId,
742  Value start) const {
743  b.create<memref::StoreOp>(l, start, subSectPosBuf,
744  ValueRange{tupleId, C_IDX(tupleSz)});
745  }
746 
747  Value loadNxLvlStart(OpBuilder &b, Location l, Value tupleId) const {
748  return b.create<memref::LoadOp>(l, subSectPosBuf,
749  ValueRange{tupleId, C_IDX(tupleSz)});
750  }
751 
752  void storeCursorVals(OpBuilder &b, Location l, Value tupleId,
753  ValueRange itVals) const {
754  assert(itVals.size() == tupleSz);
755  for (unsigned i = 0; i < tupleSz; i++) {
756  b.create<memref::StoreOp>(l, itVals[i], subSectPosBuf,
757  ValueRange{tupleId, C_IDX(i)});
758  }
759  }
760 
761  SmallVector<Value> loadCursorVals(OpBuilder &b, Location l,
762  Value tupleId) const {
763  SmallVector<Value> ret;
764  for (unsigned i = 0; i < tupleSz; i++) {
765  Value v = b.create<memref::LoadOp>(l, subSectPosBuf,
766  ValueRange{tupleId, C_IDX(i)});
767  ret.push_back(v);
768  }
769  return ret;
770  }
771 
772  bool isSubSectRoot() const {
773  return !parent || !llvm::isa<NonEmptySubSectIterator>(parent);
774  }
775 
776  // Generate code that inflate the current subsection tree till the current
777  // level such that every leaf node is visited.
778  ValueRange inflateSubSectTree(OpBuilder &b, Location l, ValueRange reduc,
779  TraverseBuilder builder) const;
780 
781  bool isBatchIterator() const override { return delegate->isBatchIterator(); }
782  bool randomAccessible() const override {
783  return delegate->randomAccessible();
784  };
785  bool iteratableByFor() const override { return randomAccessible(); };
786  Value upperBound(OpBuilder &b, Location l) const override {
787  auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
788  Value parentUB =
789  p && p->lvl == lvl ? p->upperBound(b, l) : delegate->upperBound(b, l);
790  return ADDI(SUBI(parentUB, subSectSz), C_IDX(1));
791  };
792 
793  void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override;
794 
795  void locateImpl(OpBuilder &b, Location l, Value crd) override {
796  Value absOff = crd;
797 
798  if (isSubSectRoot())
799  delegate->locate(b, l, absOff);
800  else
801  assert(parent->lvl + 1 == lvl);
802 
803  seek(ValueRange{absOff, absOff, C_TRUE});
804  updateCrd(crd);
805  }
806 
807  Value toSubSectCrd(OpBuilder &b, Location l, Value wrapCrd) const {
808  return SUBI(wrapCrd, getAbsOff());
809  }
810 
811  Value genNotEndImpl(OpBuilder &b, Location l) override {
812  return getNotEnd();
813  };
814 
815  Value derefImpl(OpBuilder &b, Location l) override {
816  // Use the relative offset to coiterate.
817  Value crd;
818  auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
819  if (p && p->lvl == lvl)
820  crd = SUBI(getAbsOff(), p->getAbsOff());
821  crd = getAbsOff();
822 
823  updateCrd(crd);
824  return crd;
825  };
826 
827  ValueRange forwardImpl(OpBuilder &b, Location l) override;
828 
829  Value getMinCrd() const { return subSectMeta[0]; }
830  Value getAbsOff() const { return subSectMeta[1]; }
831  Value getNotEnd() const { return subSectMeta[2]; }
832 
833  const SparseIterator *parent;
834  std::unique_ptr<SparseIterator> delegate;
835 
836  // Number of values required to serialize the wrapped iterator.
837  const unsigned tupleSz;
838  // Max number of tuples, and the actual number of tuple.
839  Value maxTupleCnt, tupleCnt;
840  // The memory used to cache the tuple serialized from the wrapped iterator.
841  Value subSectPosBuf;
842 
843  const Value subSectSz;
844 
845  // minCrd, absolute offset, notEnd
846  SmallVector<Value, 3> subSectMeta{nullptr, nullptr, nullptr};
847 };
848 
849 class SubSectIterator;
850 
851 // A wrapper that helps generating code to traverse a subsection, used
852 // by both `NonEmptySubSectIterator`and `SubSectIterator`.
853 struct SubSectIterHelper {
854  explicit SubSectIterHelper(const SubSectIterator &iter);
855  explicit SubSectIterHelper(const NonEmptySubSectIterator &subSect);
856 
857  // Delegate methods.
858  void deserializeFromTupleId(OpBuilder &b, Location l, Value tupleId);
859  void locate(OpBuilder &b, Location l, Value crd);
860  Value genNotEnd(OpBuilder &b, Location l);
861  Value deref(OpBuilder &b, Location l);
862  ValueRange forward(OpBuilder &b, Location l);
863 
864  const NonEmptySubSectIterator &subSect;
866 };
867 
868 class SubSectIterator : public SparseIterator {
869 public:
870  SubSectIterator(const NonEmptySubSectIterator &subSect,
871  const SparseIterator &parent,
872  std::unique_ptr<SparseIterator> &&wrap)
874  /*extraCursorCnt=*/wrap->randomAccessible() ? 0 : 1),
875  subSect(subSect), wrap(std::move(wrap)), parent(parent), helper(*this) {
876  assert(subSect.tid == tid && subSect.lvl == lvl);
877  assert(parent.kind != IterKind::kSubSect || parent.lvl + 1 == lvl);
878  };
879 
880  // For LLVM-style RTTI.
881  static bool classof(const SparseIterator *from) {
882  return from->kind == IterKind::kSubSect;
883  }
884 
885  std::string getDebugInterfacePrefix() const override {
886  return std::string("subsect<") + wrap->getDebugInterfacePrefix() + ">";
887  }
888  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
889  SmallVector<Type> ret = wrap->getCursorValTypes(b);
890  if (!randomAccessible())
891  ret.push_back(b.getIndexType()); // The extra counter.
892  return ret;
893  }
894 
895  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
896  bool randomAccessible() const override { return wrap->randomAccessible(); };
897  bool iteratableByFor() const override { return randomAccessible(); };
898  Value upperBound(OpBuilder &b, Location l) const override {
899  return subSect.subSectSz;
900  }
901 
902  ValueRange getCurPosition() const override { return wrap->getCurPosition(); };
903 
904  Value getNxLvlTupleId(OpBuilder &b, Location l) const {
905  if (randomAccessible()) {
906  return ADDI(getCrd(), nxLvlTupleStart);
907  };
908  return ADDI(getCursor().back(), nxLvlTupleStart);
909  }
910 
911  void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override {
912  if (randomAccessible()) {
913  if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
914  assert(p->lvl + 1 == lvl);
915  wrap->genInit(b, l, p);
916  // Linearize the dense subsection index.
917  nxLvlTupleStart = MULI(subSect.subSectSz, p->getNxLvlTupleId(b, l));
918  } else {
919  assert(subSect.lvl == lvl && subSect.isSubSectRoot());
920  wrap->deserialize(subSect.delegate->serialize());
921  nxLvlTupleStart = C_IDX(0);
922  }
923  return;
924  }
925  assert(!randomAccessible());
926  assert(getCursor().size() == wrap->getCursor().size() + 1);
927  // Extra counter that counts the number of actually visited coordinates in
928  // the sparse subsection.
929  getMutCursorVals().back() = C_IDX(0);
930  Value tupleId;
931  if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
932  assert(p->lvl + 1 == lvl);
933  tupleId = p->getNxLvlTupleId(b, l);
934  } else {
935  assert(subSect.lvl == lvl && subSect.isSubSectRoot());
936  tupleId = C_IDX(0);
937  }
938  nxLvlTupleStart = subSect.loadNxLvlStart(b, l, tupleId);
939  helper.deserializeFromTupleId(b, l, tupleId);
940  }
941 
942  void locateImpl(OpBuilder &b, Location l, Value crd) override {
943  helper.locate(b, l, crd);
944  updateCrd(crd);
945  }
946 
947  Value genNotEndImpl(OpBuilder &b, Location l) override {
948  return helper.genNotEnd(b, l);
949  }
950 
951  Value derefImpl(OpBuilder &b, Location l) override {
952  Value crd = helper.deref(b, l);
953  updateCrd(crd);
954  return crd;
955  };
956 
957  ValueRange forwardImpl(OpBuilder &b, Location l) override {
958  helper.forward(b, l);
959  assert(!randomAccessible());
960  assert(getCursor().size() == wrap->getCursor().size() + 1);
961  getMutCursorVals().back() = ADDI(getCursor().back(), C_IDX(1));
962  return getCursor();
963  };
964 
965  Value nxLvlTupleStart;
966 
967  const NonEmptySubSectIterator &subSect;
968  std::unique_ptr<SparseIterator> wrap;
969  const SparseIterator &parent;
970 
971  SubSectIterHelper helper;
972 };
973 
974 } // namespace
975 
976 //===----------------------------------------------------------------------===//
977 // SparseIterator derived classes implementation.
978 //===----------------------------------------------------------------------===//
979 
981  const SparseIterator *p) {
983  std::string prefix = getDebugInterfacePrefix();
984  Operation *begin = b.create(l, b.getStringAttr(prefix + ".begin"), {},
985  getCursorValTypes(b));
986  seek(begin->getResults());
987  return;
988  }
989  // Inherent batch coordinates from parents.
990  if (p)
991  inherentBatch(*p);
992  // TODO: support lowering to function call.
993  return genInitImpl(b, l, p);
994 }
995 
998  std::string prefix = getDebugInterfacePrefix();
999  Operation *notEnd = b.create(l, b.getStringAttr(prefix + ".not_end"),
1000  getCursor(), b.getI1Type());
1001  return notEnd->getResult(0);
1002  }
1003  // TODO: support lowering to function call.
1004  return genNotEndImpl(b, l);
1005 }
1006 
1009  std::string prefix = getDebugInterfacePrefix();
1010  SmallVector<Value> args = getCursor();
1011  args.push_back(crd);
1012  Operation *locate = b.create(l, b.getStringAttr(prefix + ".locate"), args,
1013  getCursorValTypes(b));
1014  seek(locate->getResults());
1015  updateCrd(crd);
1016  return;
1017  }
1018  return locateImpl(b, l, crd);
1019 }
1020 
1023  std::string prefix = getDebugInterfacePrefix();
1024  SmallVector<Value> args = getCursor();
1025  Operation *deref = b.create(l, b.getStringAttr(prefix + ".deref"),
1026  getCursor(), b.getIndexType());
1027  updateCrd(deref->getResult(0));
1028  return getCrd();
1029  }
1030  return derefImpl(b, l);
1031 }
1032 
1034  assert(!randomAccessible());
1036  std::string prefix = getDebugInterfacePrefix();
1037  Operation *next = b.create(l, b.getStringAttr(prefix + ".next"),
1039  seek(next->getResults());
1040  return getCursor();
1041  }
1042  return forwardImpl(b, l);
1043 }
1044 
1046  auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), cond, true);
1047  // Generate else branch first, otherwise iterator values will be updated by
1048  // `forward()`.
1049  b.setInsertionPointToStart(ifOp.elseBlock());
1050  YIELD(getCursor());
1051 
1052  b.setInsertionPointToStart(ifOp.thenBlock());
1053  YIELD(forward(b, l));
1054 
1055  b.setInsertionPointAfter(ifOp);
1056  seek(ifOp.getResults());
1057  return getCursor();
1058 }
1059 
1060 Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
1061  auto whileOp = b.create<scf::WhileOp>(
1062  l, pos.getType(), pos,
1063  /*beforeBuilder=*/
1064  [this, pos](OpBuilder &b, Location l, ValueRange ivs) {
1065  Value inBound = CMPI(ult, ivs.front(), posHi);
1066  auto ifInBound = b.create<scf::IfOp>(l, b.getI1Type(), inBound, true);
1067  {
1068  OpBuilder::InsertionGuard guard(b);
1069  // If in bound, load the next coordinates and check duplication.
1070  b.setInsertionPointToStart(ifInBound.thenBlock());
1071  Value headCrd = stl.peekCrdAt(b, l, getBatchCrds(), pos);
1072  Value tailCrd = stl.peekCrdAt(b, l, getBatchCrds(), ivs.front());
1073  Value isDup = CMPI(eq, headCrd, tailCrd);
1074  YIELD(isDup);
1075  // Else, the position is out of bound, yield false.
1076  b.setInsertionPointToStart(ifInBound.elseBlock());
1077  YIELD(constantI1(b, l, false));
1078  }
1079  b.create<scf::ConditionOp>(l, ifInBound.getResults()[0], ivs);
1080  },
1081  /*afterBuilder=*/
1082  [](OpBuilder &b, Location l, ValueRange ivs) {
1083  Value nxPos = ADDI(ivs[0], C_IDX(1));
1084  YIELD(nxPos);
1085  });
1086  // Return the segment high.
1087  return whileOp.getResult(0);
1088 }
1089 
1090 Value FilterIterator::genCrdNotLegitPredicate(OpBuilder &b, Location l,
1091  Value wrapCrd) {
1092  Value crd = fromWrapCrd(b, l, wrapCrd);
1093  // Test whether the coordinate is on stride.
1094  Value notlegit = CMPI(ne, toWrapCrd(b, l, crd), wrapCrd);
1095  // Test wrapCrd < offset
1096  notlegit = ORI(CMPI(ult, wrapCrd, offset), notlegit);
1097  // Test crd >= length
1098  notlegit = ORI(CMPI(uge, crd, size), notlegit);
1099  return notlegit;
1100 }
1101 
1102 Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
1103  auto r = genWhenInBound(
1104  b, l, *wrap, C_FALSE,
1105  [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1106  Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
1107  return {notLegit};
1108  });
1109 
1110  assert(r.size() == 1);
1111  return r.front();
1112 }
1113 
1114 Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) {
1115  assert(!wrap->randomAccessible());
1116  auto r = genWhenInBound(
1117  b, l, *wrap, C_FALSE,
1118  [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1119  Value crd = fromWrapCrd(b, l, wrapCrd);
1120  // crd < size
1121  return {CMPI(ult, crd, size)};
1122  });
1123  assert(r.size() == 1);
1124  return r.front();
1125 }
1126 
1127 ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
1128  assert(!randomAccessible());
1129  // Generates
1130  //
1131  // bool isFirst = true;
1132  // while !it.end() && (!legit(*it) || isFirst)
1133  // wrap ++;
1134  // isFirst = false;
1135  //
1136  // We do not hoist the first `wrap++` outside the loop but use a `isFirst`
1137  // flag here because `wrap++` might have a complex implementation (e.g., to
1138  // forward a subsection).
1139  Value isFirst = constantI1(b, l, true);
1140 
1141  SmallVector<Value> whileArgs(getCursor().begin(), getCursor().end());
1142  whileArgs.push_back(isFirst);
1143  auto whileOp = b.create<scf::WhileOp>(
1144  l, ValueRange(whileArgs).getTypes(), whileArgs,
1145  /*beforeBuilder=*/
1146  [this](OpBuilder &b, Location l, ValueRange ivs) {
1147  ValueRange isFirst = linkNewScope(ivs);
1148  assert(isFirst.size() == 1);
1149  scf::ValueVector cont =
1150  genWhenInBound(b, l, *wrap, C_FALSE,
1151  [this, isFirst](OpBuilder &b, Location l,
1152  Value wrapCrd) -> scf::ValueVector {
1153  // crd < size && !legit();
1154  Value notLegit =
1155  genCrdNotLegitPredicate(b, l, wrapCrd);
1156  Value crd = fromWrapCrd(b, l, wrapCrd);
1157  Value ret = ANDI(CMPI(ult, crd, size), notLegit);
1158  ret = ORI(ret, isFirst.front());
1159  return {ret};
1160  });
1161  b.create<scf::ConditionOp>(l, cont.front(), ivs);
1162  },
1163  /*afterBuilder=*/
1164  [this](OpBuilder &b, Location l, ValueRange ivs) {
1165  linkNewScope(ivs);
1166  wrap->forward(b, l);
1167  SmallVector<Value> yieldVals(getCursor().begin(), getCursor().end());
1168  yieldVals.push_back(constantI1(b, l, false));
1169  YIELD(yieldVals);
1170  });
1171 
1172  b.setInsertionPointAfter(whileOp);
1173  linkNewScope(whileOp.getResults());
1174  return getCursor();
1175 }
1176 
1177 SubSectIterHelper::SubSectIterHelper(const NonEmptySubSectIterator &subSect)
1178  : subSect(subSect), wrap(*subSect.delegate) {}
1179 
1180 SubSectIterHelper::SubSectIterHelper(const SubSectIterator &iter)
1181  : subSect(iter.subSect), wrap(*iter.wrap) {}
1182 
1183 void SubSectIterHelper::deserializeFromTupleId(OpBuilder &b, Location l,
1184  Value tupleId) {
1185  assert(!subSect.randomAccessible());
1186  wrap.deserialize(subSect.loadCursorVals(b, l, tupleId));
1187 }
1188 
1189 void SubSectIterHelper::locate(OpBuilder &b, Location l, Value crd) {
1190  Value absCrd = ADDI(crd, subSect.getAbsOff());
1191  wrap.locate(b, l, absCrd);
1192 }
1193 
1194 Value SubSectIterHelper::genNotEnd(OpBuilder &b, Location l) {
1195  assert(!wrap.randomAccessible());
1196  auto r = genWhenInBound(
1197  b, l, wrap, C_FALSE,
1198  [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1199  Value crd = SUBI(wrapCrd, subSect.getAbsOff());
1200  // crd < size
1201  return {CMPI(ult, crd, subSect.subSectSz)};
1202  });
1203  assert(r.size() == 1);
1204  return r.front();
1205 }
1206 
1207 Value SubSectIterHelper::deref(OpBuilder &b, Location l) {
1208  Value wrapCrd = wrap.deref(b, l);
1209  Value crd = subSect.toSubSectCrd(b, l, wrapCrd);
1210  return crd;
1211 }
1212 
1213 ValueRange SubSectIterHelper::forward(OpBuilder &b, Location l) {
1214  return wrap.forward(b, l);
1215 }
1216 
1217 ValueRange NonEmptySubSectIterator::inflateSubSectTree(
1218  OpBuilder &b, Location l, ValueRange reduc, TraverseBuilder builder) const {
1219  // Set up the helper to help traverse a sparse subsection.
1220  SubSectIterHelper helper(*this);
1221  if (!randomAccessible()) {
1222  // The subsection tree have been expanded till the level and cached,
1223  // traverse all the leaves and expanded to the next level.
1224  SmallVector<Value> iterArgs;
1225  iterArgs.push_back(C_IDX(0));
1226  iterArgs.append(reduc.begin(), reduc.end());
1227  auto forEachLeaf = b.create<scf::ForOp>(
1228  l, /*lb=*/C_IDX(0), /*ub=*/tupleCnt, /*step=*/C_IDX(1), iterArgs,
1229  [&helper, &builder](OpBuilder &b, Location l, Value tupleId,
1230  ValueRange iterArgs) {
1231  // Deserialize the iterator at the cached position (tupleId).
1232  helper.deserializeFromTupleId(b, l, tupleId);
1233 
1234  Value cnt = iterArgs.front();
1235  // Record the number of leaf nodes included in the subsection.
1236  // The number indicates the starting tupleId for the next level that
1237  // is corresponding to the current node.
1238  helper.subSect.storeNxLvlStart(b, l, tupleId, cnt);
1239 
1240  SmallVector<Value> whileArgs(helper.wrap.getCursor());
1241  whileArgs.append(iterArgs.begin(), iterArgs.end());
1242 
1243  auto whileOp = b.create<scf::WhileOp>(
1244  l, ValueRange(whileArgs).getTypes(), whileArgs,
1245  /*beforeBuilder=*/
1246  [&helper](OpBuilder &b, Location l, ValueRange ivs) {
1247  helper.wrap.linkNewScope(ivs);
1248  b.create<scf::ConditionOp>(l, helper.genNotEnd(b, l), ivs);
1249  },
1250  /*afterBuilder=*/
1251  [&helper, &builder](OpBuilder &b, Location l, ValueRange ivs) {
1252  ValueRange remIter = helper.wrap.linkNewScope(ivs);
1253  Value cnt = remIter.front();
1254  ValueRange userIter = remIter.drop_front();
1255  scf::ValueVector userNx = builder(b, l, &helper.wrap, userIter);
1256 
1257  SmallVector<Value> nxIter = helper.forward(b, l);
1258  nxIter.push_back(ADDI(cnt, C_IDX(1)));
1259  nxIter.append(userNx.begin(), userNx.end());
1260  YIELD(nxIter);
1261  });
1262  ValueRange res = helper.wrap.linkNewScope(whileOp.getResults());
1263  YIELD(res);
1264  });
1265  return forEachLeaf.getResults().drop_front();
1266  }
1267 
1268  assert(randomAccessible());
1269  // Helper lambda that traverse the current dense subsection range.
1270  auto visitDenseSubSect = [&, this](OpBuilder &b, Location l,
1271  const SparseIterator *parent,
1272  ValueRange reduc) {
1273  assert(!parent || parent->lvl + 1 == lvl);
1274  delegate->genInit(b, l, parent);
1275  auto forOp = b.create<scf::ForOp>(
1276  l, /*lb=*/C_IDX(0), /*ub=*/subSectSz, /*step=*/C_IDX(1), reduc,
1277  [&](OpBuilder &b, Location l, Value crd, ValueRange iterArgs) {
1278  helper.locate(b, l, crd);
1279  scf::ValueVector nx = builder(b, l, &helper.wrap, iterArgs);
1280  YIELD(nx);
1281  });
1282  return forOp.getResults();
1283  };
1284 
1285  if (isSubSectRoot()) {
1286  return visitDenseSubSect(b, l, parent, reduc);
1287  }
1288  // Else, this is not the root, recurse until root.
1289  auto *p = llvm::cast<NonEmptySubSectIterator>(parent);
1290  assert(p->lvl + 1 == lvl);
1291  return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
1292 }
1293 
1294 void TrivialIterator::genInitImpl(OpBuilder &b, Location l,
1295  const SparseIterator *parent) {
1296 
1297  if (isBatchIterator() && batchCrds.size() <= stl.lvl)
1298  batchCrds.resize(stl.lvl + 1, nullptr);
1299 
1300  Value c0 = C_IDX(0);
1301  ValueRange pPos = c0;
1302  Value inPadZone = nullptr;
1303  // If the parent iterator is a batch iterator, we also start from 0 (but
1304  // on a different batch).
1305  if (parent && !parent->isBatchIterator()) {
1306  pPos = parent->getCurPosition();
1307  if (llvm::isa<PadIterator>(parent) && parent->randomAccessible()) {
1308  // A padded dense iterator create "sparse" padded zone, which need to be
1309  // handled specially.
1310  inPadZone = pPos.back();
1311  pPos = pPos.drop_back();
1312  }
1313  }
1314 
1315  ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
1316  std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos, inPadZone);
1317  // Seek to the lowest position.
1318  seek(posLo);
1319 }
1320 
1321 void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
1322  const SparseIterator *) {
1323  Value c0 = C_IDX(0);
1324  if (!isSubSectRoot()) {
1325  assert(parent->lvl + 1 == lvl);
1326  if (randomAccessible()) {
1327  // We can not call wrap->genInit() here to initialize the wrapped
1328  // iterator, because the parent of the curent iterator is still
1329  // unresolved.
1330  seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
1331  return;
1332  }
1333 
1334  auto *p = cast<NonEmptySubSectIterator>(parent);
1335  SmallVector<Value, 3> reduc = {
1336  C_IDX(-1), // minCrd (max signless integer)
1337  c0, // tupleId
1338  };
1339 
1340  // Expand the subsection tree from the parent level to the current level.
1341  ValueRange result = p->inflateSubSectTree(
1342  b, l, reduc,
1343  [this](OpBuilder &b, Location l, const SparseIterator *parent,
1344  ValueRange reduc) -> scf::ValueVector {
1345  assert(parent->lvl + 1 == lvl && reduc.size() == 2);
1346  Value minCrd = reduc.front();
1347  Value tupleId = reduc.back();
1348 
1349  // Initialize the subsection range.
1350  SubSectIterHelper helper(*this);
1351  helper.wrap.genInit(b, l, parent);
1352 
1353  // Update minCrd.
1354  minCrd = genWhenInBound(b, l, helper.wrap, minCrd,
1355  [minCrd](OpBuilder &b, Location l,
1356  Value crd) -> scf::ValueVector {
1357  Value min = MINUI(crd, minCrd);
1358  return {min};
1359  })
1360  .front();
1361 
1362  // Cache the sparse range.
1363  storeCursorVals(b, l, tupleId, helper.wrap.serialize());
1364  tupleId = ADDI(tupleId, C_IDX(1));
1365  return {minCrd, tupleId};
1366  });
1367  assert(result.size() == 2);
1368  tupleCnt = result.back();
1369 
1370  Value minCrd = result.front();
1371  Value absOff = offsetFromMinCrd(b, l, minCrd, subSectSz);
1372  Value notEnd = CMPI(ne, minCrd, C_IDX(-1));
1373  seek({minCrd, absOff, notEnd});
1374  return;
1375  }
1376 
1377  // This is the root level of the subsection, which means that it is resolved
1378  // to one node.
1379  assert(isSubSectRoot());
1380 
1381  // Initialize the position, the position marks the *lower bound* of the
1382  // subRange. The higher bound is determined by the size of the subsection.
1383  delegate->genInit(b, l, parent);
1384  if (randomAccessible()) {
1385  seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
1386  return;
1387  }
1388 
1389  // Only have one root node.
1390  tupleCnt = C_IDX(1);
1391  // Cache the sparse range.
1392  storeCursorVals(b, l, c0, delegate->serialize());
1393  SmallVector<Value> elseRet{c0, c0, /*notEnd=*/C_FALSE};
1394  auto meta = genWhenInBound(
1395  b, l, *delegate, elseRet,
1396  [this](OpBuilder &b, Location l, Value crd) -> scf::ValueVector {
1397  Value offset = offsetFromMinCrd(b, l, crd, subSectSz);
1398  return {crd, offset, C_TRUE};
1399  });
1400 
1401  seek(meta);
1402 }
1403 
1404 ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
1405  assert(!randomAccessible());
1406  Value c0 = C_IDX(0), c1 = C_IDX(1);
1407  // Forward to the next non empty slice by generating
1408  //
1409  // if (minCrd > offset) {
1410  // offset += 1
1411  // } else {
1412  // minCrd = nextMinInSlice();
1413  // offset = minCrd - size + 1;
1414  // }
1415  //
1416  // if (offset + size > parents.size)
1417  // isNonEmpty = false;
1418  Value fastPathP = CMPI(ugt, getMinCrd(), getAbsOff());
1419  auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), fastPathP, true);
1420  {
1421  OpBuilder::InsertionGuard guard(b);
1422  // Take the fast path
1423  // if (minCrd > offset)
1424  // offset += 1
1425  b.setInsertionPointToStart(&ifOp.getThenRegion().front());
1426  Value nxOffset = ADDI(getAbsOff(), c1);
1427  YIELD((ValueRange{getMinCrd(), nxOffset, getNotEnd()}));
1428 
1429  // else /*minCrd == offset*/ {
1430  // for (i = 0; i < tupleCnt; i++) {
1431  // wrap->deserialize(pos[i]);
1432  // minCrd=min(minCrd, *wrap);
1433  // }
1434  // offset = minCrd - size + 1;
1435  // }
1436  b.setInsertionPointToStart(&ifOp.getElseRegion().front());
1437  SmallVector<Value, 2> loopArgs{C_IDX(-1), // nextMinCrd
1438  C_FALSE}; // isNotEnd
1439  auto loopNest = scf::buildLoopNest(
1440  b, l, c0, tupleCnt, c1, loopArgs,
1441  [this](OpBuilder &b, Location l, ValueRange ivs,
1442  ValueRange iterArgs) -> scf::ValueVector {
1443  Value tupleId = ivs.front();
1444  SubSectIterHelper helper(*this);
1445  helper.deserializeFromTupleId(b, l, tupleId);
1446 
1447  return genWhenInBound(
1448  b, l, *delegate, /*elseRet=*/iterArgs,
1449  [this, iterArgs, tupleId](OpBuilder &b, Location l,
1450  Value crd) -> scf::ValueVector {
1451  // if coord == minCrd
1452  // wrap->forward();
1453  Value isMin = CMPI(eq, crd, getMinCrd());
1454  delegate->forwardIf(b, l, isMin);
1455  // Update the forwarded iterator values if needed.
1456  auto ifIsMin = b.create<scf::IfOp>(l, isMin, false);
1457  b.setInsertionPointToStart(&ifIsMin.getThenRegion().front());
1458  storeCursorVals(b, l, tupleId, delegate->serialize());
1459  b.setInsertionPointAfter(ifIsMin);
1460  // if (!wrap.end())
1461  // yield(min(nxMinCrd, *wrap), true)
1462  Value nxMin = iterArgs[0];
1463  return genWhenInBound(b, l, *delegate, /*elseRet=*/iterArgs,
1464  [nxMin](OpBuilder &b, Location l,
1465  Value crd) -> scf::ValueVector {
1466  Value nx = b.create<arith::MinUIOp>(
1467  l, crd, nxMin);
1468  return {nx, C_TRUE};
1469  });
1470  });
1471  });
1472 
1473  scf::ForOp forOp = loopNest.loops.front();
1474  b.setInsertionPointAfter(forOp);
1475 
1476  Value nxMinCrd = forOp.getResult(0);
1477  Value nxNotEnd = forOp.getResult(1);
1478  Value nxAbsOff = offsetFromMinCrd(b, l, nxMinCrd, subSectSz);
1479  YIELD((ValueRange{nxMinCrd, nxAbsOff, nxNotEnd}));
1480  }
1481 
1482  Value nxMinCrd = ifOp.getResult(0);
1483  Value nxAbsOff = ifOp.getResult(1);
1484  Value nxNotEnd = ifOp.getResult(2);
1485 
1486  // We should at least forward the offset by one.
1487  Value minAbsOff = ADDI(getAbsOff(), c1);
1488  nxAbsOff = b.create<arith::MaxUIOp>(l, minAbsOff, nxAbsOff);
1489 
1490  seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
1491  // The coordinate should not exceeds the space upper bound.
1492  Value crd = deref(b, l);
1493  nxNotEnd = ANDI(nxNotEnd, CMPI(ult, crd, upperBound(b, l)));
1494 
1495  seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
1496  return getCursor();
1497 }
1498 
1499 //===----------------------------------------------------------------------===//
1500 // SparseIterationSpace Implementation
1501 //===----------------------------------------------------------------------===//
1502 
1504  Location l, OpBuilder &b, Value t, unsigned tid,
1505  std::pair<Level, Level> lvlRange, ValueRange parentPos)
1506  : lvls() {
1507  auto [lvlLo, lvlHi] = lvlRange;
1508 
1509  Value c0 = C_IDX(0);
1510  if (parentPos.empty())
1511  parentPos = c0;
1512 
1513  for (Level lvl = lvlLo; lvl < lvlHi; lvl++)
1514  lvls.emplace_back(makeSparseTensorLevel(b, l, t, tid, lvl));
1515 
1516  bound = lvls.front()->peekRangeAt(b, l, /*batchPrefix=*/{}, parentPos);
1517  for (auto &lvl : getLvlRef().drop_front())
1518  bound = lvl->collapseRangeBetween(b, l, /*batchPrefix=*/{}, bound);
1519 }
1520 
1522  IterSpaceType dstTp, ValueRange values, unsigned int tid) {
1523  // Reconstruct every sparse tensor level.
1524  SparseIterationSpace space;
1525  for (auto [i, lt] : llvm::enumerate(dstTp.getLvlTypes())) {
1526  unsigned bufferCnt = 0;
1527  if (lt.isWithPosLT())
1528  bufferCnt++;
1529  if (lt.isWithCrdLT())
1530  bufferCnt++;
1531  // Sparse tensor buffers.
1532  ValueRange buffers = values.take_front(bufferCnt);
1533  values = values.drop_front(bufferCnt);
1534 
1535  // Level size.
1536  Value sz = values.front();
1537  values = values.drop_front();
1538  space.lvls.push_back(
1539  makeSparseTensorLevel(lt, sz, buffers, tid, i + dstTp.getLoLvl()));
1540  }
1541  // Two bounds.
1542  space.bound = std::make_pair(values[0], values[1]);
1543  values = values.drop_front(2);
1544 
1545  // Must have consumed all values.
1546  assert(values.empty());
1547  return space;
1548 }
1549 
1550 std::unique_ptr<SparseIterator>
1552  return makeSimpleIterator(b, l, *this);
1553 }
1554 
1555 //===----------------------------------------------------------------------===//
1556 // SparseIterator factory functions.
1557 //===----------------------------------------------------------------------===//
1558 
1559 /// Helper function to create a TensorLevel object from given `tensor`.
1560 std::unique_ptr<SparseTensorLevel>
1562  unsigned t, Level l) {
1563  assert(lt.getNumBuffer() == b.size());
1564  switch (lt.getLvlFmt()) {
1565  case LevelFormat::Dense:
1566  return std::make_unique<DenseLevel>(t, l, sz);
1567  case LevelFormat::Batch:
1568  return std::make_unique<BatchLevel>(t, l, sz);
1570  return std::make_unique<CompressedLevel>(t, l, lt, sz, b[0], b[1]);
1572  return std::make_unique<LooseCompressedLevel>(t, l, lt, sz, b[0], b[1]);
1574  return std::make_unique<SingletonLevel>(t, l, lt, sz, b[0]);
1575  case LevelFormat::NOutOfM:
1576  return std::make_unique<NOutOfMLevel>(t, l, lt, sz, b[0]);
1577  case LevelFormat::Undef:
1578  llvm_unreachable("undefined level format");
1579  }
1580  llvm_unreachable("unrecognizable level format");
1581 }
1582 
1583 std::unique_ptr<SparseTensorLevel>
1585  unsigned tid, Level lvl) {
1586  auto stt = getSparseTensorType(t);
1587 
1588  LevelType lt = stt.getLvlType(lvl);
1589  Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
1590  : b.create<tensor::DimOp>(l, t, lvl).getResult();
1591 
1592  SmallVector<Value, 2> buffers;
1593  if (lt.isWithPosLT()) {
1594  Value pos = b.create<ToPositionsOp>(l, t, lvl);
1595  buffers.push_back(pos);
1596  }
1597  if (lt.isWithCrdLT()) {
1598  Value pos = b.create<ToCoordinatesOp>(l, t, lvl);
1599  buffers.push_back(pos);
1600  }
1601  return makeSparseTensorLevel(lt, sz, buffers, tid, lvl);
1602 }
1603 
1604 std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
1605 sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
1606  SparseEmitStrategy strategy) {
1607  auto stl = std::make_unique<BatchLevel>(tid, lvl, sz);
1608  auto it = std::make_unique<TrivialIterator>(*stl);
1609  it->setSparseEmitStrategy(strategy);
1610  return std::make_pair(std::move(stl), std::move(it));
1611 }
1612 
1613 std::unique_ptr<SparseIterator>
1615  const SparseIterationSpace &iterSpace) {
1616  // assert(iterSpace.getSpaceDim() == 1);
1617  std::unique_ptr<SparseIterator> ret;
1618  if (!iterSpace.isUnique()) {
1619  // We always dedupliate the non-unique level, but we should optimize it away
1620  // if possible.
1621  ret = std::make_unique<DedupIterator>(b, l, iterSpace.getLastLvl(),
1622  iterSpace.getBoundLo(),
1623  iterSpace.getBoundHi());
1624  } else {
1625  ret = std::make_unique<TrivialIterator>(b, l, iterSpace.getLastLvl(),
1626  iterSpace.getBoundLo(),
1627  iterSpace.getBoundHi());
1628  }
1629  ret->setSparseEmitStrategy(SparseEmitStrategy::kFunctional);
1630  return ret;
1631 }
1632 
1633 std::unique_ptr<SparseIterator>
1635  SparseEmitStrategy strategy) {
1636  std::unique_ptr<SparseIterator> ret;
1637  if (!isUniqueLT(stl.getLT())) {
1638  // We always dedupliate the non-unique level, but we should optimize it away
1639  // if possible.
1640  ret = std::make_unique<DedupIterator>(stl);
1641  } else {
1642  ret = std::make_unique<TrivialIterator>(stl);
1643  }
1644  ret->setSparseEmitStrategy(strategy);
1645  return ret;
1646 }
1647 
1648 std::unique_ptr<SparseIterator>
1649 sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
1650  Value offset, Value stride, Value size,
1651  SparseEmitStrategy strategy) {
1652 
1653  auto ret =
1654  std::make_unique<FilterIterator>(std::move(sit), offset, stride, size);
1655  ret->setSparseEmitStrategy(strategy);
1656  return ret;
1657 }
1658 
1659 std::unique_ptr<SparseIterator>
1660 sparse_tensor::makePaddedIterator(std::unique_ptr<SparseIterator> &&sit,
1661  Value padLow, Value padHigh,
1662  SparseEmitStrategy strategy) {
1663  auto ret = std::make_unique<PadIterator>(std::move(sit), padLow, padHigh);
1664  ret->setSparseEmitStrategy(strategy);
1665  return ret;
1666 }
1667 
1669  auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
1670  if (filter)
1671  return &filter->getWrappedIterator();
1672  return it;
1673 }
1674 
1675 std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
1676  OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
1677  std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
1678  SparseEmitStrategy strategy) {
1679 
1680  // Try unwrap the NonEmptySubSectIterator from a filter parent.
1681  parent = tryUnwrapFilter(parent);
1682  std::unique_ptr<SparseIterator> it =
1683  std::make_unique<NonEmptySubSectIterator>(b, l, parent,
1684  std::move(delegate), size);
1685 
1686  if (stride != 1) {
1687  // TODO: We can safely skip bound checking on sparse levels, but for dense
1688  // iteration space, we need the bound to infer the dense loop range.
1689  it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1690  C_IDX(stride), /*size=*/loopBound);
1691  }
1692  it->setSparseEmitStrategy(strategy);
1693  return it;
1694 }
1695 
1696 std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator(
1697  OpBuilder &b, Location l, const SparseIterator &subSectIter,
1698  const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
1699  Value loopBound, unsigned stride, SparseEmitStrategy strategy) {
1700 
1701  // This must be a subsection iterator or a filtered subsection iterator.
1702  auto &subSect =
1703  llvm::cast<NonEmptySubSectIterator>(*tryUnwrapFilter(&subSectIter));
1704 
1705  std::unique_ptr<SparseIterator> it = std::make_unique<SubSectIterator>(
1706  subSect, *tryUnwrapFilter(&parent), std::move(wrap));
1707 
1708  if (stride != 1) {
1709  it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1710  C_IDX(stride), /*size=*/loopBound);
1711  }
1712  it->setSparseEmitStrategy(strategy);
1713  return it;
1714 }
1715 
// Undefine every file-local helper macro defined at the top of this file so
// that none of them leak into code that follows (e.g. if this file is ever
// concatenated into a unity build).
#undef CMPI
#undef C_FALSE
#undef C_TRUE
#undef C_IDX
#undef YIELD
#undef ADDI
#undef ORI
#undef ANDI
#undef SUBI
#undef MULI
#undef MINUI
#undef REMUI
#undef DIVUI
#undef SELECT
bool isUnique(It begin, It end)
Definition: MeshOps.cpp:140
#define SELECT(c, lhs, rhs)
#define C_FALSE
#define SUBI(lhs, rhs)
#define MULI(lhs, rhs)
#define C_IDX(v)
#define ANDI(lhs, rhs)
std::tuple< Value, Value, Value > ValueTuple
#define C_TRUE
static scf::ValueVector genWhenInBound(OpBuilder &b, Location l, SparseIterator &it, ValueRange elseRet, llvm::function_ref< scf::ValueVector(OpBuilder &, Location, Value)> builder)
#define YIELD(vs)
static const SparseIterator * tryUnwrapFilter(const SparseIterator *it)
#define ORI(lhs, rhs)
#define DIVUI(lhs, rhs)
std::pair< Value, Value > ValuePair
#define CMPI(p, lhs, rhs)
#define ADDI(lhs, rhs)
static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd, Value size)
Generates code to compute the absolute offset of the slice based on the provide minimum coordinates i...
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:302
IntegerType getI1Type()
Definition: Builders.cpp:97
IndexType getIndexType()
Definition: Builders.cpp:95
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:439
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:420
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
result_range getResults()
Definition: Operation.h:410
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
A SparseIterationSpace represents a sparse set of coordinates defined by (possibly multiple) levels o...
const SparseTensorLevel & getLastLvl() const
static SparseIterationSpace fromValues(IterSpaceType dstTp, ValueRange values, unsigned tid)
std::unique_ptr< SparseIterator > extractIterator(OpBuilder &b, Location l) const
ArrayRef< std::unique_ptr< SparseTensorLevel > > getLvlRef() const
Helper class that generates loop conditions, etc, to traverse a sparse tensor level.
virtual void genInitImpl(OpBuilder &, Location, const SparseIterator *)=0
ValueRange forward(OpBuilder &b, Location l)
void setSparseEmitStrategy(SparseEmitStrategy strategy)
virtual bool isBatchIterator() const =0
virtual void locateImpl(OpBuilder &b, Location l, Value crd)
void genInit(OpBuilder &b, Location l, const SparseIterator *p)
virtual Value derefImpl(OpBuilder &b, Location l)=0
Value genNotEnd(OpBuilder &b, Location l)
void locate(OpBuilder &b, Location l, Value crd)
virtual ValueRange forwardIf(OpBuilder &b, Location l, Value cond)
virtual Value genNotEndImpl(OpBuilder &b, Location l)=0
void inherentBatch(const SparseIterator &parent)
virtual std::string getDebugInterfacePrefix() const =0
virtual ValueRange getCurPosition() const
Value deref(OpBuilder &b, Location l)
virtual bool randomAccessible() const =0
virtual ValueRange forwardImpl(OpBuilder &b, Location l)=0
virtual SmallVector< Type > getCursorValTypes(OpBuilder &b) const =0
The base class for all types of sparse tensor levels.
virtual std::pair< Value, Value > peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix, ValueRange parentPos, Value inPadZone=nullptr) const =0
Peeks the lower and upper bound to fully traverse the level with the given position parentPos,...
virtual Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix, Value iv) const =0
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition: Diagnostics.h:24
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition: SCF.cpp:687
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:70
bool isUniqueLT(LevelType lt)
Definition: Enums.h:428
std::unique_ptr< SparseTensorLevel > makeSparseTensorLevel(OpBuilder &b, Location l, Value t, unsigned tid, Level lvl)
Helper function to create a TensorLevel object from given tensor.
std::unique_ptr< SparseIterator > makeTraverseSubSectIterator(OpBuilder &b, Location l, const SparseIterator &subsectIter, const SparseIterator &parent, std::unique_ptr< SparseIterator > &&wrap, Value loopBound, unsigned stride, SparseEmitStrategy strategy)
Helper function to create a SparseIterator object that iterates over a non-empty subsection created b...
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:42
uint64_t getN(LevelType lt)
Definition: Enums.h:442
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
Definition: CodegenUtils.h:356
Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, ValueRange s)
Generates a pointer/index load from the sparse storage scheme.
std::pair< std::unique_ptr< SparseTensorLevel >, std::unique_ptr< SparseIterator > > makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl, SparseEmitStrategy strategy)
Helper function to create a synthetic SparseIterator object that iterates over a dense space specifie...
std::unique_ptr< SparseIterator > makePaddedIterator(std::unique_ptr< SparseIterator > &&sit, Value padLow, Value padHigh, SparseEmitStrategy strategy)
Helper function to create a SparseIterator object that iterates over a padded sparse level (the padde...
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
std::unique_ptr< SparseIterator > makeSimpleIterator(OpBuilder &b, Location l, const SparseIterationSpace &iterSpace)
Helper function to create a simple SparseIterator object that iterate over the entire iteration space...
std::unique_ptr< SparseIterator > makeSlicedLevelIterator(std::unique_ptr< SparseIterator > &&sit, Value offset, Value stride, Value size, SparseEmitStrategy strategy)
Helper function to create a SparseIterator object that iterates over a sliced space,...
std::unique_ptr< SparseIterator > makeNonEmptySubSectIterator(OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound, std::unique_ptr< SparseIterator > &&delegate, Value size, unsigned stride, SparseEmitStrategy strategy)
Helper function to create a SparseIterator object that iterate over the non-empty subsections set.
OwningOpRef< spirv::ModuleOp > deserialize(ArrayRef< uint32_t > binary, MLIRContext *context)
Deserializes the given SPIR-V binary module and creates a MLIR ModuleOp in the given context.
LogicalResult serialize(ModuleOp module, SmallVectorImpl< uint32_t > &binary, const SerializationOptions &options={})
Serializes the given SPIR-V module and writes to binary.
Include the generated interface declarations.
SparseEmitStrategy
Defines a scope for reinterpret map pass.
Definition: Passes.h:52
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LoopVector loops
Definition: SCF.h:73
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238
constexpr unsigned getNumBuffer() const
Definition: Enums.h:360
constexpr bool hasDenseSemantic() const
Check if the LevelType is considered to be dense-like.
Definition: Enums.h:343
constexpr LevelFormat getLvlFmt() const
Get the LevelFormat of the LevelType.
Definition: Enums.h:320
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.
Definition: Enums.h:326
constexpr bool isWithPosLT() const
Check if the LevelType needs positions array.
Definition: Enums.h:348
constexpr bool isWithCrdLT() const
Check if the LevelType needs coordinates array.
Definition: Enums.h:354