SparseTensorIterator.cpp
1 //===- SparseTensorIterator.cpp -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "SparseTensorIterator.h"
10 #include "CodegenUtils.h"
11 
15 
16 using namespace mlir;
17 using namespace mlir::sparse_tensor;
18 using ValuePair = std::pair<Value, Value>;
19 using ValueTuple = std::tuple<Value, Value, Value>;
20 
21 //===----------------------------------------------------------------------===//
22 // File local helper functions/macros.
23 //===----------------------------------------------------------------------===//
24 #define CMPI(p, lhs, rhs) \
25  (b.create<arith::CmpIOp>(l, arith::CmpIPredicate::p, (lhs), (rhs)) \
26  .getResult())
27 
28 #define C_FALSE (constantI1(b, l, false))
29 #define C_TRUE (constantI1(b, l, true))
30 #define C_IDX(v) (constantIndex(b, l, (v)))
31 #define YIELD(vs) (b.create<scf::YieldOp>(l, (vs)))
32 #define ADDI(lhs, rhs) (b.create<arith::AddIOp>(l, (lhs), (rhs)).getResult())
33 #define ORI(lhs, rhs) (b.create<arith::OrIOp>(l, (lhs), (rhs)).getResult())
34 #define ANDI(lhs, rhs) (b.create<arith::AndIOp>(l, (lhs), (rhs)).getResult())
35 #define SUBI(lhs, rhs) (b.create<arith::SubIOp>(l, (lhs), (rhs)).getResult())
36 #define MULI(lhs, rhs) (b.create<arith::MulIOp>(l, (lhs), (rhs)).getResult())
37 #define MINUI(lhs, rhs) (b.create<arith::MinUIOp>(l, (lhs), (rhs)).getResult())
38 #define REMUI(lhs, rhs) (b.create<arith::RemUIOp>(l, (lhs), (rhs)).getResult())
39 #define DIVUI(lhs, rhs) (b.create<arith::DivUIOp>(l, (lhs), (rhs)).getResult())
40 #define SELECT(c, lhs, rhs) \
41  (b.create<arith::SelectOp>(l, (c), (lhs), (rhs)).getResult())
42 
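// Illustrative only (assuming an OpBuilder `b` and a Location `l` in scope,
// as in every method below): the helper macros compose into IR-building
// expressions such as
//   Value next = ADDI(pos, C_IDX(1));
//   Value stop = SELECT(CMPI(ult, next, hi), next, hi);
// which keeps the arith/scf construction code in this file compact.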
43 //===----------------------------------------------------------------------===//
44 // SparseTensorLevel derived classes.
45 //===----------------------------------------------------------------------===//
46 
47 namespace {
48 
49 template <bool hasPosBuffer>
50 class SparseLevel : public SparseTensorLevel {
51  // It is either an array of size 2 or size 1 depending on whether the sparse
52  // level requires a position array.
53  using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
54  std::array<Value, 1>>;
55 
56 public:
57  SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
58  BufferT buffers)
59  : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}
60 
61  ValueRange getLvlBuffers() const override { return buffers; }
62 
63  Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
64  Value iv) const override {
65  SmallVector<Value> memCrd(batchPrefix);
66  memCrd.push_back(iv);
67  return genIndexLoad(b, l, getCrdBuf(), memCrd);
68  }
69 
70 protected:
71  template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
72  Value getPosBuf() const {
73  return buffers[0];
74  }
75 
76  Value getCrdBuf() const {
77  if constexpr (hasPosBuffer)
78  return buffers[1];
79  else
80  return buffers[0];
81  }
82 
83  const BufferT buffers;
84 };
85 
86 class DenseLevel : public SparseTensorLevel {
87 public:
88  DenseLevel(unsigned tid, Level lvl, Value lvlSize)
89  : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize) {}
90 
91  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
92  llvm_unreachable("locate random-accessible level instead");
93  }
94 
95  ValueRange getLvlBuffers() const override { return {}; }
96 
97  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
98  ValueRange parentPos, Value inPadZone) const override {
99  assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
100  assert(!inPadZone && "Not implemented");
101  Value p = parentPos.front();
102  Value posLo = MULI(p, lvlSize);
103  return {posLo, lvlSize};
104  }
105 };
106 
107 class BatchLevel : public SparseTensorLevel {
108 public:
109  BatchLevel(unsigned tid, Level lvl, Value lvlSize)
110  : SparseTensorLevel(tid, lvl, LevelFormat::Batch, lvlSize) {}
111 
112  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
113  llvm_unreachable("locate random-accessible level instead");
114  }
115 
116  ValueRange getLvlBuffers() const override { return {}; }
117 
118  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
119  ValueRange parentPos, Value inPadZone) const override {
120  assert(!inPadZone && "Not implemented");
121  assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
122  // No need to linearize the position for non-annotated tensors.
123  return {C_IDX(0), lvlSize};
124  }
125 };
126 
127 class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
128 public:
129  CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
130  Value posBuffer, Value crdBuffer)
131  : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
132 
133  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
134  ValueRange parentPos, Value inPadZone) const override {
135 
136  assert(parentPos.size() == 1 &&
137  "compressed level must be the first non-unique level.");
138 
139  auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
140  Value p = parentPos.front();
141  SmallVector<Value> memCrd(batchPrefix);
142  memCrd.push_back(p);
143  Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
144  memCrd.back() = ADDI(p, C_IDX(1));
145  Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
146  return {pLo, pHi};
147  };
148 
149  if (inPadZone == nullptr)
150  return loadRange();
151 
153  scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
154  // True branch, returns a "fake" empty range [0, 0) if parent
155  // iterator is in pad zone.
156  b.setInsertionPointToStart(posRangeIf.thenBlock());
157 
158  SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
159  b.create<scf::YieldOp>(l, emptyRange);
160 
161  // False branch, returns the actual range.
162  b.setInsertionPointToStart(posRangeIf.elseBlock());
163  auto [pLo, pHi] = loadRange();
164  SmallVector<Value, 2> loadedRange{pLo, pHi};
165  b.create<scf::YieldOp>(l, loadedRange);
166 
167  b.setInsertionPointAfter(posRangeIf);
168  ValueRange posRange = posRangeIf.getResults();
169  return {posRange.front(), posRange.back()};
170  }
 171 };
172 
173 class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
174 public:
175  LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
176  Value posBuffer, Value crdBuffer)
177  : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
178 
179  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
180  ValueRange parentPos, Value inPadZone) const override {
181  assert(parentPos.size() == 1 &&
182  "loose-compressed level must be the first non-unique level.");
183  assert(!inPadZone && "Not implemented");
184  SmallVector<Value> memCrd(batchPrefix);
185  Value p = parentPos.front();
186  p = MULI(p, C_IDX(2));
187  memCrd.push_back(p);
188  Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
189  memCrd.back() = ADDI(p, C_IDX(1));
190  Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
191  return {pLo, pHi};
192  }
 193 };
194 
195 class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
196 public:
197  SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
198  Value crdBuffer)
199  : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
200 
201  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
202  ValueRange parentPos, Value inPadZone) const override {
203  assert(parentPos.size() == 1 || parentPos.size() == 2);
204  assert(!inPadZone && "Not implemented");
205  Value p = parentPos.front();
206  Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
207 
208  if (segHi == nullptr)
209  return {p, ADDI(p, C_IDX(1))};
210  // Use the segHi as the loop upper bound.
211  return {p, segHi};
212  }
213 
214  ValuePair
215  collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
216  std::pair<Value, Value> parentRange) const override {
217  // Singleton level keeps the same range after collapsing.
218  return parentRange;
219  };
220 };
221 
222 class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
223 public:
224  NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
225  Value crdBuffer)
226  : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
227 
228  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
229  ValueRange parentPos, Value inPadZone) const override {
230  assert(parentPos.size() == 1 && isUnique() &&
231  "n:m level can not be non-unique.");
232  assert(!inPadZone && "Not implemented");
 233  // Each n:m block has exactly n specified elements.
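  // For example (illustrative numbers): in a 2:4 level, getN(lt) == 2, so a
  // parent position of 3 yields the position range [6, 8) that holds the two
  // stored elements of that block.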
234  auto n = getN(lt);
235  Value posLo = MULI(parentPos.front(), C_IDX(n));
236  return {posLo, ADDI(posLo, C_IDX(n))};
237  }
238 };
239 
240 } // namespace
241 
242 //===----------------------------------------------------------------------===//
243 // File local helpers
244 //===----------------------------------------------------------------------===//
245 
 246 static scf::ValueVector genWhenInBound(
 247     OpBuilder &b, Location l, SparseIterator &it, ValueRange elseRet,
 248     llvm::function_ref<scf::ValueVector(OpBuilder &, Location, Value)>
 249         builder) {
250  TypeRange ifRetTypes = elseRet.getTypes();
251  auto ifOp = b.create<scf::IfOp>(l, ifRetTypes, it.genNotEnd(b, l), true);
252 
253  b.setInsertionPointToStart(ifOp.thenBlock());
254  Value crd = it.deref(b, l);
255  scf::ValueVector ret = builder(b, l, crd);
256  YIELD(ret);
257 
258  b.setInsertionPointToStart(ifOp.elseBlock());
259  YIELD(elseRet);
260 
261  b.setInsertionPointAfter(ifOp);
262  return ifOp.getResults();
263 }
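// Usage sketch (mirroring the calls later in this file): the callback runs
// only when the iterator has not ended, e.g.
//   auto r = genWhenInBound(b, l, it, /*elseRet=*/C_FALSE,
//       [&](OpBuilder &b, Location l, Value crd) -> scf::ValueVector {
//         return {CMPI(ult, crd, size)};
//       });
// yields `crd < size` inside the bound check and `false` otherwise.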
264 
265 /// Generates code to compute the *absolute* offset of the slice based on the
 266 /// provided minimum coordinates in the slice.
 267 /// E.g., when reducing d0 + d1 + d2, we need two slices to fully reduce the
 268 /// expression, i.e., s1 = slice(T, d0), s2 = slice(s1, d1). The *absolute*
 269 /// offset is the offset computed relative to the initial tensor T.
270 ///
 271 /// When isNonEmpty == true, the computed offset is meaningless and should not
 272 /// be used at runtime; the method currently generates code that returns 0 in
 273 /// that case.
274 ///
275 /// offset = minCrd >= size ? minCrd - size + 1 : 0;
 276 static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
 277                               Value size) {
278  Value geSize = CMPI(uge, minCrd, size);
279  // Compute minCrd - size + 1.
280  Value mms = SUBI(ADDI(minCrd, C_IDX(1)), size);
281  // This is the absolute offset related to the actual tensor.
282  return SELECT(geSize, mms, C_IDX(0));
283 }
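// A small worked example of the formula above (values chosen for exposition
// only): with size = 4 and minCrd = 9, the offset is 9 - 4 + 1 = 6, i.e.,
// [6, 10) is the first subsection whose range still covers coordinate 9;
// with minCrd = 2 < 4 the offset stays 0.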
284 
285 //===----------------------------------------------------------------------===//
286 // SparseIterator derived classes.
287 //===----------------------------------------------------------------------===//
288 
289 namespace {
290 
 291 // The iterator that traverses a concrete sparse tensor level. High-level
 292 // abstract iterators wrap it to achieve more complex goals (such as collapsing
 293 // several levels). It also owns the common storage that holds the mlir::Values
 294 // for itself as well as for its wrappers.
295 class ConcreteIterator : public SparseIterator {
296 protected:
297  ConcreteIterator(const SparseTensorLevel &stl, IterKind kind,
298  unsigned cursorValCnt)
299  : SparseIterator(kind, stl.tid, stl.lvl, cursorValCnt, cursorValsStorage),
300  stl(stl), cursorValsStorage(cursorValCnt, nullptr) {
301  assert(getCursor().size() == cursorValCnt);
302  };
303 
304 public:
305  // For LLVM-style RTTI.
306  static bool classof(const SparseIterator *from) {
307  return from->kind == IterKind::kTrivial;
308  }
309 
310  bool isBatchIterator() const override {
311  return stl.getLT().isa<LevelFormat::Batch>();
312  }
313  bool randomAccessible() const override {
314  return stl.getLT().hasDenseSemantic();
315  };
316  bool iteratableByFor() const override { return kind != IterKind::kDedup; };
317  Value upperBound(OpBuilder &b, Location l) const override {
318  return stl.getSize();
319  };
320 
321 protected:
322  const SparseTensorLevel &stl;
 323  // Owner of the storage; all wrappers built on top of a concrete iterator
 324  // share the same storage so that the iterator values are always
325  // synchronized.
326  SmallVector<Value> cursorValsStorage;
327 };
328 
329 class TrivialIterator : public ConcreteIterator {
330 public:
331  TrivialIterator(const SparseTensorLevel &stl)
332  : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
333 
334  TrivialIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
335  Value posLo, Value posHi)
336  : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1), posLo(posLo),
337  posHi(posHi) {
338  seek(posLo);
339  }
340 
341  std::string getDebugInterfacePrefix() const override {
342  return std::string("trivial<") + stl.toString() + ">";
343  }
344  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
345  return {b.getIndexType()};
346  }
347 
348  SmallVector<Value> serialize() const override {
349  SmallVector<Value> ret;
350  ret.push_back(getItPos());
351  if (randomAccessible()) {
 352  // The loop upper bound is implicit (defined by `upperBound()`) for a
 353  // random-access iterator, but we need to memorize posLo for linearization.
354  ret.push_back(posLo);
355  } else {
356  ret.push_back(posHi);
357  }
358  return ret;
359  };
360 
361  void deserialize(ValueRange vs) override {
362  assert(vs.size() == 2);
363  seek(vs.front());
364  if (randomAccessible())
365  posLo = vs.back();
366  else
367  posHi = vs.back();
368  };
369 
370  void genInitImpl(OpBuilder &b, Location l,
371  const SparseIterator *parent) override;
372 
373  ValuePair genForCond(OpBuilder &b, Location l) override {
374  if (randomAccessible())
375  return {deref(b, l), upperBound(b, l)};
376  return std::make_pair(getItPos(), posHi);
377  }
378 
379  Value genNotEndImpl(OpBuilder &b, Location l) override {
 380  // We use the first level's bound as the bound for the collapsed set of levels.
381  return CMPI(ult, getItPos(), posHi);
382  }
383 
384  Value derefImpl(OpBuilder &b, Location l) override {
385  if (randomAccessible()) {
386  updateCrd(SUBI(getItPos(), posLo));
387  } else {
388  updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getItPos()));
389  }
390  return getCrd();
391  };
392 
393  ValueRange forwardImpl(OpBuilder &b, Location l) override {
394  seek(ADDI(getItPos(), C_IDX(1)));
395  return getCursor();
396  }
397 
398  ValueRange forwardIf(OpBuilder &b, Location l, Value cond) override {
399  Value curPos = getCursor().front();
400  Value nxPos = forward(b, l).front();
401  seek(SELECT(cond, nxPos, curPos));
402  return getCursor();
403  }
404 
405  void locateImpl(OpBuilder &b, Location l, Value crd) override {
406  assert(randomAccessible());
407  // Seek to the linearized position.
408  seek(ADDI(crd, posLo));
409  updateCrd(crd);
410  if (isBatchIterator()) {
411  // If this is a batch iterator, also update the batch coordinate.
412  assert(batchCrds.size() > lvl);
413  batchCrds[lvl] = crd;
414  }
415  }
416 
417  Value getItPos() const { return getCursor().front(); }
418  Value posLo, posHi;
419 };
420 
421 class DedupIterator : public ConcreteIterator {
422 private:
423  Value genSegmentHigh(OpBuilder &b, Location l, Value pos);
424 
425 public:
426  DedupIterator(const SparseTensorLevel &stl)
427  : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) {
428  assert(!stl.isUnique());
429  }
430 
431  DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
432  Value posLo, Value posHi)
433  : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2), posHi(posHi) {
434  assert(!stl.isUnique());
435  seek({posLo, genSegmentHigh(b, l, posLo)});
436  }
437 
438  // For LLVM-style RTTI.
439  static bool classof(const SparseIterator *from) {
440  return from->kind == IterKind::kDedup;
441  }
442 
443  std::string getDebugInterfacePrefix() const override {
444  return std::string("dedup<") + stl.toString() + ">";
445  }
446  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
447  return {b.getIndexType(), b.getIndexType()};
448  }
449 
450  void genInitImpl(OpBuilder &b, Location l,
451  const SparseIterator *parent) override {
452  Value c0 = C_IDX(0);
453  ValueRange pPos = c0;
454 
455  // If the parent iterator is a batch iterator, we also start from 0 (but
456  // on a different batch).
457  if (parent && !parent->isBatchIterator())
458  pPos = parent->getCurPosition();
459 
460  Value posLo;
461  ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
462  std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
463 
464  seek({posLo, genSegmentHigh(b, l, posLo)});
465  }
466 
467  SmallVector<Value> serialize() const override {
468  SmallVector<Value> ret;
469  ret.append(getCursor().begin(), getCursor().end());
470  ret.push_back(posHi);
471  return ret;
472  };
473  void deserialize(ValueRange vs) override {
474  assert(vs.size() == 3);
475  seek(vs.take_front(getCursor().size()));
476  posHi = vs.back();
477  };
478 
479  Value genNotEndImpl(OpBuilder &b, Location l) override {
480  return CMPI(ult, getPos(), posHi);
481  }
482 
483  Value derefImpl(OpBuilder &b, Location l) override {
484  updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getPos()));
485  return getCrd();
486  };
487 
488  ValueRange forwardImpl(OpBuilder &b, Location l) override {
489  Value nxPos = getSegHi(); // forward the position to the next segment.
490  seek({nxPos, genSegmentHigh(b, l, nxPos)});
491  return getCursor();
492  }
493 
494  Value getPos() const { return getCursor()[0]; }
495  Value getSegHi() const { return getCursor()[1]; }
496 
497  Value posHi;
498 };
499 
 500 // A utility base iterator that delegates all methods to the wrapped iterator.
501 class SimpleWrapIterator : public SparseIterator {
502 public:
503  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
504  unsigned extraCursorVal = 0)
505  : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}
506 
507  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
508  return wrap->getCursorValTypes(b);
509  }
510  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
511  bool randomAccessible() const override { return wrap->randomAccessible(); };
512  bool iteratableByFor() const override { return wrap->iteratableByFor(); };
513 
514  SmallVector<Value> serialize() const override { return wrap->serialize(); };
515  void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
516  ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
517  void genInitImpl(OpBuilder &b, Location l,
518  const SparseIterator *parent) override {
519  wrap->genInit(b, l, parent);
520  }
521  Value genNotEndImpl(OpBuilder &b, Location l) override {
522  return wrap->genNotEndImpl(b, l);
523  }
524  ValueRange forwardImpl(OpBuilder &b, Location l) override {
525  return wrap->forward(b, l);
526  };
527  Value upperBound(OpBuilder &b, Location l) const override {
528  return wrap->upperBound(b, l);
529  };
530 
531  Value derefImpl(OpBuilder &b, Location l) override {
532  return wrap->derefImpl(b, l);
533  }
534 
535  void locateImpl(OpBuilder &b, Location l, Value crd) override {
536  return wrap->locate(b, l, crd);
537  }
538 
539  SparseIterator &getWrappedIterator() const { return *wrap; }
540 
541 protected:
542  std::unique_ptr<SparseIterator> wrap;
543 };
544 
545 //
 546 // A filter iterator wrapped around another iterator. The filter iterator
 547 // updates the wrapped iterator *in-place*.
548 //
549 class FilterIterator : public SimpleWrapIterator {
 550  // Coordinate translation between the crd loaded from the wrapped iterator
 551  // and the filter iterator's crd.
552  Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) const {
553  // crd = (wrapCrd - offset) / stride
554  return DIVUI(SUBI(wrapCrd, offset), stride);
555  }
556  Value toWrapCrd(OpBuilder &b, Location l, Value crd) const {
557  // wrapCrd = crd * stride + offset
558  return ADDI(MULI(crd, stride), offset);
559  }
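  // Illustrative example (values for exposition only): with offset = 2 and
  // stride = 3, wrapped coordinates 2, 5, 8 map to filter coordinates
  // 0, 1, 2, and toWrapCrd(fromWrapCrd(c)) == c holds exactly for on-stride
  // coordinates; genCrdNotLegitPredicate uses that round-trip to reject
  // off-stride coordinates such as 4.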
560 
561  Value genCrdNotLegitPredicate(OpBuilder &b, Location l, Value wrapCrd);
562 
563  Value genShouldFilter(OpBuilder &b, Location l);
564 
565 public:
 566  // TODO: avoid unnecessary checks when offset == 0 and/or when stride == 1
 567  // and/or when crd is always < size.
568  FilterIterator(std::unique_ptr<SparseIterator> &&wrap, Value offset,
569  Value stride, Value size)
570  : SimpleWrapIterator(std::move(wrap), IterKind::kFilter), offset(offset),
571  stride(stride), size(size) {}
572 
573  // For LLVM-style RTTI.
574  static bool classof(const SparseIterator *from) {
575  return from->kind == IterKind::kFilter;
576  }
577 
578  std::string getDebugInterfacePrefix() const override {
579  return std::string("filter<") + wrap->getDebugInterfacePrefix() + ">";
580  }
581 
582  bool iteratableByFor() const override { return randomAccessible(); };
583  Value upperBound(OpBuilder &b, Location l) const override { return size; };
584 
585  void genInitImpl(OpBuilder &b, Location l,
586  const SparseIterator *parent) override {
587  wrap->genInit(b, l, parent);
588  if (!randomAccessible()) {
589  // TODO: we can skip this when stride == 1 and offset == 0, we can also
590  // use binary search here.
591  forwardIf(b, l, genShouldFilter(b, l));
592  } else {
593  // Else, locate to the slice.offset, which is the first coordinate
594  // included by the slice.
595  wrap->locate(b, l, offset);
596  }
597  }
598 
599  Value genNotEndImpl(OpBuilder &b, Location l) override;
600 
601  Value derefImpl(OpBuilder &b, Location l) override {
602  updateCrd(fromWrapCrd(b, l, wrap->deref(b, l)));
603  return getCrd();
604  }
605 
606  void locateImpl(OpBuilder &b, Location l, Value crd) override {
607  assert(randomAccessible());
608  wrap->locate(b, l, toWrapCrd(b, l, crd));
609  updateCrd(crd);
610  }
611 
612  ValueRange forwardImpl(OpBuilder &b, Location l) override;
613 
614  Value offset, stride, size;
615 };
616 
617 //
 618 // A pad iterator wrapped around another iterator. The pad iterator updates
619 // the wrapped iterator *in-place*.
620 //
621 class PadIterator : public SimpleWrapIterator {
622 
623 public:
624  PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
625  Value padHigh)
626  : SimpleWrapIterator(std::move(wrap), IterKind::kPad,
627  wrap->randomAccessible() ? 1 : 0),
628  padLow(padLow), padHigh(padHigh) {}
629 
630  // For LLVM-style RTTI.
631  static bool classof(const SparseIterator *from) {
632  return from->kind == IterKind::kPad;
633  }
634 
635  std::string getDebugInterfacePrefix() const override {
636  return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
637  }
638 
 639  // Returns a pair of values for the *lower* and *upper* bound, respectively.
640  ValuePair genForCond(OpBuilder &b, Location l) override {
641  if (randomAccessible())
642  return {getCrd(), upperBound(b, l)};
643  return wrap->genForCond(b, l);
644  }
645 
 646  // For a padded dense iterator, we append an `inPadZone: bool` in addition
 647  // to the values used by the wrapped iterator.
648  ValueRange getCurPosition() const override { return getCursor(); }
649 
650  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
651  SmallVector<Type> ret = wrap->getCursorValTypes(b);
652  // Need an extra boolean value `inPadZone` for padded dense iterator.
653  if (randomAccessible())
654  ret.push_back(b.getI1Type());
655 
656  return ret;
657  }
658 
659  // The upper bound after padding becomes `size + padLow + padHigh`.
660  Value upperBound(OpBuilder &b, Location l) const override {
661  return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
662  };
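  // Illustrative example (numbers made up for exposition): wrapping a level
  // of size 5 with padLow = 2 and padHigh = 1 gives an upper bound of 8;
  // padded coordinates 0, 1 and 7 fall in the pad zone, while padded
  // coordinate 3 maps back to wrapped coordinate 1 in locateImpl below.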
663 
664  // The pad_coord = coord + pad_lo
665  Value derefImpl(OpBuilder &b, Location l) override {
666  updateCrd(ADDI(wrap->deref(b, l), padLow));
667  return getCrd();
668  }
669 
670  void locateImpl(OpBuilder &b, Location l, Value crd) override {
671  assert(randomAccessible());
672  wrap->locate(b, l, SUBI(crd, padLow));
673 
674  // inPadZone = crd < padLow || crd >= size + padLow.
675  Value inPadLow = CMPI(ult, crd, padLow);
676  Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow));
677  getMutCursorVals().back() = ORI(inPadLow, inPadHigh);
678 
679  updateCrd(crd);
680  }
681 
682  Value padLow, padHigh;
683 };
684 
685 class NonEmptySubSectIterator : public SparseIterator {
686 public:
 687  using TraverseBuilder = llvm::function_ref<scf::ValueVector(
 688      OpBuilder &, Location, const SparseIterator *, ValueRange)>;
 689 
690  NonEmptySubSectIterator(OpBuilder &b, Location l,
691  const SparseIterator *parent,
692  std::unique_ptr<SparseIterator> &&delegate,
693  Value subSectSz)
694  : SparseIterator(IterKind::kNonEmptySubSect, 3, subSectMeta, *delegate),
695  parent(parent), delegate(std::move(delegate)),
696  tupleSz(this->delegate->serialize().size()), subSectSz(subSectSz) {
697  auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
698  if (p == nullptr) {
699  // Extract subsections along the root level.
700  maxTupleCnt = C_IDX(1);
701  } else if (p->lvl == lvl) {
702  // Extract subsections along the same level.
703  maxTupleCnt = p->maxTupleCnt;
704  assert(false && "Not implemented.");
705  } else {
706  // Extract subsections along the previous level.
707  assert(p->lvl + 1 == lvl);
708  maxTupleCnt = MULI(p->maxTupleCnt, p->subSectSz);
709  }
710  // We don't need an extra buffer to find subsections on random-accessible
711  // levels.
712  if (randomAccessible())
713  return;
714  subSectPosBuf = allocSubSectPosBuf(b, l);
715  }
716 
717  // For LLVM-style RTTI.
718  static bool classof(const SparseIterator *from) {
719  return from->kind == IterKind::kNonEmptySubSect;
720  }
721 
722  std::string getDebugInterfacePrefix() const override {
723  return std::string("ne_sub<") + delegate->getDebugInterfacePrefix() + ">";
724  }
725  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
726  // minCrd, absolute offset, notEnd
727  return {b.getIndexType(), b.getIndexType(), b.getI1Type()};
728  }
729 
730  // The sliced pointer buffer is organized as:
731  // [[itVal0, itVal1, ..., pNx0],
732  // [itVal0, itVal1, ..., pNx0],
733  // ...]
734  Value allocSubSectPosBuf(OpBuilder &b, Location l) {
735  return b.create<memref::AllocaOp>(
736  l,
737  MemRefType::get({ShapedType::kDynamic, tupleSz + 1}, b.getIndexType()),
738  maxTupleCnt);
739  }
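  // For example (shape shown for exposition only): if the delegate iterator
  // serializes into tupleSz = 2 values, the buffer allocated above has type
  // memref<?x3xindex>; columns 0-1 cache the serialized cursor and column 2
  // (i.e., C_IDX(tupleSz)) stores the starting tupleId of the next level,
  // matching storeNxLvlStart/loadNxLvlStart below.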
740 
741  void storeNxLvlStart(OpBuilder &b, Location l, Value tupleId,
742  Value start) const {
743  b.create<memref::StoreOp>(l, start, subSectPosBuf,
744  ValueRange{tupleId, C_IDX(tupleSz)});
745  }
746 
747  Value loadNxLvlStart(OpBuilder &b, Location l, Value tupleId) const {
748  return b.create<memref::LoadOp>(l, subSectPosBuf,
749  ValueRange{tupleId, C_IDX(tupleSz)});
750  }
751 
752  void storeCursorVals(OpBuilder &b, Location l, Value tupleId,
753  ValueRange itVals) const {
754  assert(itVals.size() == tupleSz);
755  for (unsigned i = 0; i < tupleSz; i++) {
756  b.create<memref::StoreOp>(l, itVals[i], subSectPosBuf,
757  ValueRange{tupleId, C_IDX(i)});
758  }
759  }
760 
761  SmallVector<Value> loadCursorVals(OpBuilder &b, Location l,
762  Value tupleId) const {
763  SmallVector<Value> ret;
764  for (unsigned i = 0; i < tupleSz; i++) {
765  Value v = b.create<memref::LoadOp>(l, subSectPosBuf,
766  ValueRange{tupleId, C_IDX(i)});
767  ret.push_back(v);
768  }
769  return ret;
770  }
771 
772  bool isSubSectRoot() const {
773  return !parent || !llvm::isa<NonEmptySubSectIterator>(parent);
774  }
775 
 776  // Generates code that inflates the current subsection tree up to the current
 777  // level such that every leaf node is visited.
778  ValueRange inflateSubSectTree(OpBuilder &b, Location l, ValueRange reduc,
779  TraverseBuilder builder) const;
780 
781  bool isBatchIterator() const override { return delegate->isBatchIterator(); }
782  bool randomAccessible() const override {
783  return delegate->randomAccessible();
784  };
785  bool iteratableByFor() const override { return randomAccessible(); };
786  Value upperBound(OpBuilder &b, Location l) const override {
787  auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
788  Value parentUB =
789  p && p->lvl == lvl ? p->upperBound(b, l) : delegate->upperBound(b, l);
790  return ADDI(SUBI(parentUB, subSectSz), C_IDX(1));
791  };
792 
793  void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override;
794 
795  void locateImpl(OpBuilder &b, Location l, Value crd) override {
796  Value absOff = crd;
797 
798  if (isSubSectRoot())
799  delegate->locate(b, l, absOff);
800  else
801  assert(parent->lvl + 1 == lvl);
802 
803  seek(ValueRange{absOff, absOff, C_TRUE});
804  updateCrd(crd);
805  }
806 
807  Value toSubSectCrd(OpBuilder &b, Location l, Value wrapCrd) const {
808  return SUBI(wrapCrd, getAbsOff());
809  }
810 
811  Value genNotEndImpl(OpBuilder &b, Location l) override {
812  return getNotEnd();
813  };
814 
815  Value derefImpl(OpBuilder &b, Location l) override {
816  // Use the relative offset to coiterate.
817  Value crd;
818  auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
819  if (p && p->lvl == lvl)
820  crd = SUBI(getAbsOff(), p->getAbsOff());
821  crd = getAbsOff();
822 
823  updateCrd(crd);
824  return crd;
825  };
826 
827  ValueRange forwardImpl(OpBuilder &b, Location l) override;
828 
829  Value getMinCrd() const { return subSectMeta[0]; }
830  Value getAbsOff() const { return subSectMeta[1]; }
831  Value getNotEnd() const { return subSectMeta[2]; }
832 
833  const SparseIterator *parent;
834  std::unique_ptr<SparseIterator> delegate;
835 
836  // Number of values required to serialize the wrapped iterator.
837  const unsigned tupleSz;
 839  // Max number of tuples, and the actual number of tuples.
839  Value maxTupleCnt, tupleCnt;
840  // The memory used to cache the tuple serialized from the wrapped iterator.
841  Value subSectPosBuf;
842 
843  const Value subSectSz;
844 
845  // minCrd, absolute offset, notEnd
846  SmallVector<Value, 3> subSectMeta{nullptr, nullptr, nullptr};
847 };
848 
849 class SubSectIterator;
850 
 851 // A wrapper that helps generate code to traverse a subsection, used
 852 // by both `NonEmptySubSectIterator` and `SubSectIterator`.
853 struct SubSectIterHelper {
854  explicit SubSectIterHelper(const SubSectIterator &iter);
855  explicit SubSectIterHelper(const NonEmptySubSectIterator &subSect);
856 
857  // Delegate methods.
858  void deserializeFromTupleId(OpBuilder &b, Location l, Value tupleId);
859  void locate(OpBuilder &b, Location l, Value crd);
860  Value genNotEnd(OpBuilder &b, Location l);
861  Value deref(OpBuilder &b, Location l);
862  ValueRange forward(OpBuilder &b, Location l);
863 
 864  const NonEmptySubSectIterator &subSect;
 865  SparseIterator &wrap;
 866 };
867 
868 class SubSectIterator : public SparseIterator {
869 public:
870  SubSectIterator(const NonEmptySubSectIterator &subSect,
871  const SparseIterator &parent,
 872                  std::unique_ptr<SparseIterator> &&wrap)
 873      : SparseIterator(IterKind::kSubSect, *wrap,
 874                       /*extraCursorCnt=*/wrap->randomAccessible() ? 0 : 1),
875  subSect(subSect), wrap(std::move(wrap)), parent(parent), helper(*this) {
876  assert(subSect.tid == tid && subSect.lvl == lvl);
877  assert(parent.kind != IterKind::kSubSect || parent.lvl + 1 == lvl);
878  };
879 
880  // For LLVM-style RTTI.
881  static bool classof(const SparseIterator *from) {
882  return from->kind == IterKind::kSubSect;
883  }
884 
885  std::string getDebugInterfacePrefix() const override {
886  return std::string("subsect<") + wrap->getDebugInterfacePrefix() + ">";
887  }
888  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
889  SmallVector<Type> ret = wrap->getCursorValTypes(b);
890  if (!randomAccessible())
891  ret.push_back(b.getIndexType()); // The extra counter.
892  return ret;
893  }
894 
895  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
896  bool randomAccessible() const override { return wrap->randomAccessible(); };
897  bool iteratableByFor() const override { return randomAccessible(); };
898  Value upperBound(OpBuilder &b, Location l) const override {
899  return subSect.subSectSz;
900  }
901 
902  ValueRange getCurPosition() const override { return wrap->getCurPosition(); };
903 
904  Value getNxLvlTupleId(OpBuilder &b, Location l) const {
905  if (randomAccessible()) {
906  return ADDI(getCrd(), nxLvlTupleStart);
907  };
908  return ADDI(getCursor().back(), nxLvlTupleStart);
909  }
910 
911  void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override {
912  if (randomAccessible()) {
913  if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
914  assert(p->lvl + 1 == lvl);
915  wrap->genInit(b, l, p);
916  // Linearize the dense subsection index.
917  nxLvlTupleStart = MULI(subSect.subSectSz, p->getNxLvlTupleId(b, l));
918  } else {
919  assert(subSect.lvl == lvl && subSect.isSubSectRoot());
920  wrap->deserialize(subSect.delegate->serialize());
921  nxLvlTupleStart = C_IDX(0);
922  }
923  return;
924  }
925  assert(!randomAccessible());
926  assert(getCursor().size() == wrap->getCursor().size() + 1);
927  // Extra counter that counts the number of actually visited coordinates in
928  // the sparse subsection.
929  getMutCursorVals().back() = C_IDX(0);
930  Value tupleId;
931  if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
932  assert(p->lvl + 1 == lvl);
933  tupleId = p->getNxLvlTupleId(b, l);
934  } else {
935  assert(subSect.lvl == lvl && subSect.isSubSectRoot());
936  tupleId = C_IDX(0);
937  }
938  nxLvlTupleStart = subSect.loadNxLvlStart(b, l, tupleId);
939  helper.deserializeFromTupleId(b, l, tupleId);
940  }
941 
942  void locateImpl(OpBuilder &b, Location l, Value crd) override {
943  helper.locate(b, l, crd);
944  updateCrd(crd);
945  }
946 
947  Value genNotEndImpl(OpBuilder &b, Location l) override {
948  return helper.genNotEnd(b, l);
949  }
950 
951  Value derefImpl(OpBuilder &b, Location l) override {
952  Value crd = helper.deref(b, l);
953  updateCrd(crd);
954  return crd;
955  };
956 
957  ValueRange forwardImpl(OpBuilder &b, Location l) override {
958  helper.forward(b, l);
959  assert(!randomAccessible());
960  assert(getCursor().size() == wrap->getCursor().size() + 1);
961  getMutCursorVals().back() = ADDI(getCursor().back(), C_IDX(1));
962  return getCursor();
963  };
964 
965  Value nxLvlTupleStart;
966 
967  const NonEmptySubSectIterator &subSect;
968  std::unique_ptr<SparseIterator> wrap;
969  const SparseIterator &parent;
970 
971  SubSectIterHelper helper;
972 };
973 
974 } // namespace
975 
976 //===----------------------------------------------------------------------===//
977 // SparseIterator derived classes implementation.
978 //===----------------------------------------------------------------------===//
979 
 980 void SparseIterator::genInit(OpBuilder &b, Location l,
 981                              const SparseIterator *p) {
 982   if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
 983     std::string prefix = getDebugInterfacePrefix();
984  Operation *begin = b.create(l, b.getStringAttr(prefix + ".begin"), {},
985  getCursorValTypes(b));
986  seek(begin->getResults());
987  return;
988  }
 989  // Inherit batch coordinates from the parent.
990  if (p)
991  inherentBatch(*p);
992  // TODO: support lowering to function call.
993  return genInitImpl(b, l, p);
994 }
995 
 996 Value SparseIterator::genNotEnd(OpBuilder &b, Location l) {
 997   if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
 998     std::string prefix = getDebugInterfacePrefix();
999  Operation *notEnd = b.create(l, b.getStringAttr(prefix + ".not_end"),
1000  getCursor(), b.getI1Type());
1001  return notEnd->getResult(0);
1002  }
1003  // TODO: support lowering to function call.
1004  return genNotEndImpl(b, l);
1005 }
1006 
 1007 void SparseIterator::locate(OpBuilder &b, Location l, Value crd) {
 1008   if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
 1009     std::string prefix = getDebugInterfacePrefix();
1010  SmallVector<Value> args = getCursor();
1011  args.push_back(crd);
1012  Operation *locate = b.create(l, b.getStringAttr(prefix + ".locate"), args,
1013  getCursorValTypes(b));
1014  seek(locate->getResults());
1015  updateCrd(crd);
1016  return;
1017  }
1018  return locateImpl(b, l, crd);
1019 }
1020 
 1021 Value SparseIterator::deref(OpBuilder &b, Location l) {
 1022   if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
 1023     std::string prefix = getDebugInterfacePrefix();
1024  SmallVector<Value> args = getCursor();
1025  Operation *deref = b.create(l, b.getStringAttr(prefix + ".deref"),
1026  getCursor(), b.getIndexType());
1027  updateCrd(deref->getResult(0));
1028  return getCrd();
1029  }
1030  return derefImpl(b, l);
1031 }
1032 
 1033 ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
 1034   assert(!randomAccessible());
 1035   if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
 1036     std::string prefix = getDebugInterfacePrefix();
 1037     Operation *next = b.create(l, b.getStringAttr(prefix + ".next"),
 1038                                getCursor(), getCursorValTypes(b));
 1039     seek(next->getResults());
1040  return getCursor();
1041  }
1042  return forwardImpl(b, l);
1043 }
1044 
 1045 ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) {
 1046   auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), cond, true);
1047  // Generate else branch first, otherwise iterator values will be updated by
1048  // `forward()`.
1049  b.setInsertionPointToStart(ifOp.elseBlock());
1050  YIELD(getCursor());
1051 
1052  b.setInsertionPointToStart(ifOp.thenBlock());
1053  YIELD(forward(b, l));
1054 
1055  b.setInsertionPointAfter(ifOp);
1056  seek(ifOp.getResults());
1057  return getCursor();
1058 }
1059 
1060 Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
1061  auto whileOp = b.create<scf::WhileOp>(
1062  l, pos.getType(), pos,
1063  /*beforeBuilder=*/
1064  [this, pos](OpBuilder &b, Location l, ValueRange ivs) {
1065  Value inBound = CMPI(ult, ivs.front(), posHi);
1066  auto ifInBound = b.create<scf::IfOp>(l, b.getI1Type(), inBound, true);
1067  {
1068  OpBuilder::InsertionGuard guard(b);
1069  // If in bound, load the next coordinates and check duplication.
1070  b.setInsertionPointToStart(ifInBound.thenBlock());
1071  Value headCrd = stl.peekCrdAt(b, l, getBatchCrds(), pos);
1072  Value tailCrd = stl.peekCrdAt(b, l, getBatchCrds(), ivs.front());
1073  Value isDup = CMPI(eq, headCrd, tailCrd);
1074  YIELD(isDup);
1075  // Else, the position is out of bound, yield false.
1076  b.setInsertionPointToStart(ifInBound.elseBlock());
1077  YIELD(constantI1(b, l, false));
1078  }
1079  b.create<scf::ConditionOp>(l, ifInBound.getResults()[0], ivs);
1080  },
1081  /*afterBuilder=*/
1082  [](OpBuilder &b, Location l, ValueRange ivs) {
1083  Value nxPos = ADDI(ivs[0], C_IDX(1));
1084  YIELD(nxPos);
1085  });
1086  // Return the segment high.
1087  return whileOp.getResult(0);
1088 }
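// As an illustration (made-up coordinates): for a non-unique level storing
// coordinates {2, 2, 2, 5} at positions [0, 4), genSegmentHigh(b, l, 0)
// keeps advancing while the coordinate equals the one at position 0 and
// returns 3, so the dedup iterator treats positions [0, 3) as one segment
// with coordinate 2.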
1089 
1090 Value FilterIterator::genCrdNotLegitPredicate(OpBuilder &b, Location l,
1091  Value wrapCrd) {
1092  Value crd = fromWrapCrd(b, l, wrapCrd);
1093  // Test whether the coordinate is on stride.
1094  Value notlegit = CMPI(ne, toWrapCrd(b, l, crd), wrapCrd);
1095  // Test wrapCrd < offset
1096  notlegit = ORI(CMPI(ult, wrapCrd, offset), notlegit);
1097  // Test crd >= length
1098  notlegit = ORI(CMPI(uge, crd, size), notlegit);
1099  return notlegit;
1100 }
1101 
1102 Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
1103  auto r = genWhenInBound(
1104  b, l, *wrap, C_FALSE,
1105  [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1106  Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
1107  return {notLegit};
1108  });
1109  return llvm::getSingleElement(r);
1110 }
1111 
1112 Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) {
1113  assert(!wrap->randomAccessible());
1114  auto r = genWhenInBound(
1115  b, l, *wrap, C_FALSE,
1116  [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1117  Value crd = fromWrapCrd(b, l, wrapCrd);
1118  // crd < size
1119  return {CMPI(ult, crd, size)};
1120  });
1121  return llvm::getSingleElement(r);
1122 }
1123 
1124 ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
1125  assert(!randomAccessible());
1126  // Generates
1127  //
1128  // bool isFirst = true;
1129  // while !it.end() && (!legit(*it) || isFirst)
1130  // wrap ++;
1131  // isFirst = false;
1132  //
 1133  // We do not hoist the first `wrap++` outside the loop but use an `isFirst`
1134  // flag here because `wrap++` might have a complex implementation (e.g., to
1135  // forward a subsection).
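  // A concrete trace (illustrative only, assuming a sufficiently large size):
  // with offset = 0, stride = 2, and wrapped coordinates {1, 2, 5}, the first
  // forward() takes the mandatory `isFirst` step from 1 to 2 and stops there
  // (2 is on stride); the next forward() steps to 5, rejects it as
  // off-stride, and keeps advancing until the wrapped iterator ends.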
1136  Value isFirst = constantI1(b, l, true);
1137 
1138  SmallVector<Value> whileArgs(getCursor().begin(), getCursor().end());
1139  whileArgs.push_back(isFirst);
1140  auto whileOp = b.create<scf::WhileOp>(
1141  l, ValueRange(whileArgs).getTypes(), whileArgs,
1142  /*beforeBuilder=*/
1143  [this](OpBuilder &b, Location l, ValueRange ivs) {
1144  ValueRange isFirst = linkNewScope(ivs);
1145  scf::ValueVector cont =
1146  genWhenInBound(b, l, *wrap, C_FALSE,
1147  [this, isFirst](OpBuilder &b, Location l,
1148  Value wrapCrd) -> scf::ValueVector {
1149  // crd < size && !legit();
1150  Value notLegit =
1151  genCrdNotLegitPredicate(b, l, wrapCrd);
1152  Value crd = fromWrapCrd(b, l, wrapCrd);
1153  Value ret = ANDI(CMPI(ult, crd, size), notLegit);
1154  ret = ORI(ret, llvm::getSingleElement(isFirst));
1155  return {ret};
1156  });
1157  b.create<scf::ConditionOp>(l, cont.front(), ivs);
1158  },
1159  /*afterBuilder=*/
1160  [this](OpBuilder &b, Location l, ValueRange ivs) {
1161  linkNewScope(ivs);
1162  wrap->forward(b, l);
1163  SmallVector<Value> yieldVals(getCursor().begin(), getCursor().end());
1164  yieldVals.push_back(constantI1(b, l, false));
1165  YIELD(yieldVals);
1166  });
1167 
1168  b.setInsertionPointAfter(whileOp);
1169  linkNewScope(whileOp.getResults());
1170  return getCursor();
1171 }
1172 
1173 SubSectIterHelper::SubSectIterHelper(const NonEmptySubSectIterator &subSect)
1174  : subSect(subSect), wrap(*subSect.delegate) {}
1175 
1176 SubSectIterHelper::SubSectIterHelper(const SubSectIterator &iter)
1177  : subSect(iter.subSect), wrap(*iter.wrap) {}
1178 
1179 void SubSectIterHelper::deserializeFromTupleId(OpBuilder &b, Location l,
1180  Value tupleId) {
1181  assert(!subSect.randomAccessible());
1182  wrap.deserialize(subSect.loadCursorVals(b, l, tupleId));
1183 }
1184 
1185 void SubSectIterHelper::locate(OpBuilder &b, Location l, Value crd) {
1186  Value absCrd = ADDI(crd, subSect.getAbsOff());
1187  wrap.locate(b, l, absCrd);
1188 }
1189 
1190 Value SubSectIterHelper::genNotEnd(OpBuilder &b, Location l) {
1191  assert(!wrap.randomAccessible());
1192  auto r = genWhenInBound(
1193  b, l, wrap, C_FALSE,
1194  [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1195  Value crd = SUBI(wrapCrd, subSect.getAbsOff());
1196  // crd < size
1197  return {CMPI(ult, crd, subSect.subSectSz)};
1198  });
1199  return llvm::getSingleElement(r);
1200 }
1201 
1202 Value SubSectIterHelper::deref(OpBuilder &b, Location l) {
1203  Value wrapCrd = wrap.deref(b, l);
1204  Value crd = subSect.toSubSectCrd(b, l, wrapCrd);
1205  return crd;
1206 }
1207 
1208 ValueRange SubSectIterHelper::forward(OpBuilder &b, Location l) {
1209  return wrap.forward(b, l);
1210 }
1211 
1212 ValueRange NonEmptySubSectIterator::inflateSubSectTree(
1213  OpBuilder &b, Location l, ValueRange reduc, TraverseBuilder builder) const {
1214  // Set up the helper to help traverse a sparse subsection.
1215  SubSectIterHelper helper(*this);
1216  if (!randomAccessible()) {
 1217  // The subsection tree has been expanded up to this level and cached;
 1218  // traverse all the leaves and expand them to the next level.
1219  SmallVector<Value> iterArgs;
1220  iterArgs.push_back(C_IDX(0));
1221  iterArgs.append(reduc.begin(), reduc.end());
1222  auto forEachLeaf = b.create<scf::ForOp>(
1223  l, /*lb=*/C_IDX(0), /*ub=*/tupleCnt, /*step=*/C_IDX(1), iterArgs,
1224  [&helper, &builder](OpBuilder &b, Location l, Value tupleId,
1225  ValueRange iterArgs) {
1226  // Deserialize the iterator at the cached position (tupleId).
1227  helper.deserializeFromTupleId(b, l, tupleId);
1228 
1229  Value cnt = iterArgs.front();
1230  // Record the number of leaf nodes included in the subsection.
 1231  // The number indicates the starting tupleId for the next level that
 1232  // corresponds to the current node.
1233  helper.subSect.storeNxLvlStart(b, l, tupleId, cnt);
1234 
1235  SmallVector<Value> whileArgs(helper.wrap.getCursor());
1236  whileArgs.append(iterArgs.begin(), iterArgs.end());
1237 
1238  auto whileOp = b.create<scf::WhileOp>(
1239  l, ValueRange(whileArgs).getTypes(), whileArgs,
1240  /*beforeBuilder=*/
1241  [&helper](OpBuilder &b, Location l, ValueRange ivs) {
1242  helper.wrap.linkNewScope(ivs);
1243  b.create<scf::ConditionOp>(l, helper.genNotEnd(b, l), ivs);
1244  },
1245  /*afterBuilder=*/
1246  [&helper, &builder](OpBuilder &b, Location l, ValueRange ivs) {
1247  ValueRange remIter = helper.wrap.linkNewScope(ivs);
1248  Value cnt = remIter.front();
1249  ValueRange userIter = remIter.drop_front();
1250  scf::ValueVector userNx = builder(b, l, &helper.wrap, userIter);
1251 
1252  SmallVector<Value> nxIter = helper.forward(b, l);
1253  nxIter.push_back(ADDI(cnt, C_IDX(1)));
1254  nxIter.append(userNx.begin(), userNx.end());
1255  YIELD(nxIter);
1256  });
1257  ValueRange res = helper.wrap.linkNewScope(whileOp.getResults());
1258  YIELD(res);
1259  });
1260  return forEachLeaf.getResults().drop_front();
1261  }
1262 
1263  assert(randomAccessible());
 1264  // Helper lambda that traverses the current dense subsection range.
1265  auto visitDenseSubSect = [&, this](OpBuilder &b, Location l,
1266  const SparseIterator *parent,
1267  ValueRange reduc) {
1268  assert(!parent || parent->lvl + 1 == lvl);
1269  delegate->genInit(b, l, parent);
1270  auto forOp = b.create<scf::ForOp>(
1271  l, /*lb=*/C_IDX(0), /*ub=*/subSectSz, /*step=*/C_IDX(1), reduc,
1272  [&](OpBuilder &b, Location l, Value crd, ValueRange iterArgs) {
1273  helper.locate(b, l, crd);
1274  scf::ValueVector nx = builder(b, l, &helper.wrap, iterArgs);
1275  YIELD(nx);
1276  });
1277  return forOp.getResults();
1278  };
1279 
1280  if (isSubSectRoot()) {
1281  return visitDenseSubSect(b, l, parent, reduc);
1282  }
1283  // Else, this is not the root, recurse until root.
1284  auto *p = llvm::cast<NonEmptySubSectIterator>(parent);
1285  assert(p->lvl + 1 == lvl);
1286  return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
1287 }
1288 
1289 void TrivialIterator::genInitImpl(OpBuilder &b, Location l,
1290  const SparseIterator *parent) {
1291 
1292  if (isBatchIterator() && batchCrds.size() <= stl.lvl)
1293  batchCrds.resize(stl.lvl + 1, nullptr);
1294 
1295  Value c0 = C_IDX(0);
1296  ValueRange pPos = c0;
1297  Value inPadZone = nullptr;
1298  // If the parent iterator is a batch iterator, we also start from 0 (but
1299  // on a different batch).
1300  if (parent && !parent->isBatchIterator()) {
1301  pPos = parent->getCurPosition();
1302  if (llvm::isa<PadIterator>(parent) && parent->randomAccessible()) {
 1303  // A padded dense iterator creates a "sparse" padded zone, which needs to
 1304  // be handled specially.
1305  inPadZone = pPos.back();
1306  pPos = pPos.drop_back();
1307  }
1308  }
1309 
1310  ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
1311  std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos, inPadZone);
1312  // Seek to the lowest position.
1313  seek(posLo);
1314 }
1315 
1316 void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
1317  const SparseIterator *) {
1318  Value c0 = C_IDX(0);
1319  if (!isSubSectRoot()) {
1320  assert(parent->lvl + 1 == lvl);
1321  if (randomAccessible()) {
 1322  // We cannot call wrap->genInit() here to initialize the wrapped
 1323  // iterator, because the parent of the current iterator is still
1324  // unresolved.
1325  seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
1326  return;
1327  }
1328 
1329  auto *p = cast<NonEmptySubSectIterator>(parent);
1330  SmallVector<Value, 3> reduc = {
1331  C_IDX(-1), // minCrd (max signless integer)
1332  c0, // tupleId
1333  };
1334 
1335  // Expand the subsection tree from the parent level to the current level.
1336  ValueRange result = p->inflateSubSectTree(
1337  b, l, reduc,
1338  [this](OpBuilder &b, Location l, const SparseIterator *parent,
1339  ValueRange reduc) -> scf::ValueVector {
1340  assert(parent->lvl + 1 == lvl && reduc.size() == 2);
1341  Value minCrd = reduc.front();
1342  Value tupleId = reduc.back();
1343 
1344  // Initialize the subsection range.
1345  SubSectIterHelper helper(*this);
1346  helper.wrap.genInit(b, l, parent);
1347 
1348  // Update minCrd.
1349  minCrd = genWhenInBound(b, l, helper.wrap, minCrd,
1350  [minCrd](OpBuilder &b, Location l,
1351  Value crd) -> scf::ValueVector {
1352  Value min = MINUI(crd, minCrd);
1353  return {min};
1354  })
1355  .front();
1356 
1357  // Cache the sparse range.
1358  storeCursorVals(b, l, tupleId, helper.wrap.serialize());
1359  tupleId = ADDI(tupleId, C_IDX(1));
1360  return {minCrd, tupleId};
1361  });
1362  assert(result.size() == 2);
1363  tupleCnt = result.back();
1364 
1365  Value minCrd = result.front();
1366  Value absOff = offsetFromMinCrd(b, l, minCrd, subSectSz);
1367  Value notEnd = CMPI(ne, minCrd, C_IDX(-1));
1368  seek({minCrd, absOff, notEnd});
1369  return;
1370  }
1371 
1372  // This is the root level of the subsection, which means that it is resolved
1373  // to one node.
1374  assert(isSubSectRoot());
1375 
 1376  // Initialize the position; it marks the *lower bound* of the subrange.
 1377  // The upper bound is determined by the size of the subsection.
1378  delegate->genInit(b, l, parent);
1379  if (randomAccessible()) {
1380  seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
1381  return;
1382  }
1383 
1384  // Only have one root node.
1385  tupleCnt = C_IDX(1);
1386  // Cache the sparse range.
1387  storeCursorVals(b, l, c0, delegate->serialize());
1388  SmallVector<Value> elseRet{c0, c0, /*notEnd=*/C_FALSE};
1389  auto meta = genWhenInBound(
1390  b, l, *delegate, elseRet,
1391  [this](OpBuilder &b, Location l, Value crd) -> scf::ValueVector {
1392  Value offset = offsetFromMinCrd(b, l, crd, subSectSz);
1393  return {crd, offset, C_TRUE};
1394  });
1395 
1396  seek(meta);
1397 }
1398 
1399 ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
1400  assert(!randomAccessible());
1401  Value c0 = C_IDX(0), c1 = C_IDX(1);
 1402  // Forward to the next non-empty slice by generating
1403  //
1404  // if (minCrd > offset) {
1405  // offset += 1
1406  // } else {
1407  // minCrd = nextMinInSlice();
1408  // offset = minCrd - size + 1;
1409  // }
1410  //
1411  // if (offset + size > parents.size)
1412  // isNonEmpty = false;
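  // For instance (illustrative values): with a subsection size of 3,
  // minCrd = 4 and offset = 2, the fast path merely bumps the offset to 3
  // (coordinate 4 is still covered by [3, 6)); once offset catches up with
  // minCrd, the slow path rescans all cached tuples for the next minimum
  // coordinate and recomputes the offset from it.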
1413  Value fastPathP = CMPI(ugt, getMinCrd(), getAbsOff());
1414  auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), fastPathP, true);
1415  {
1416  OpBuilder::InsertionGuard guard(b);
1417  // Take the fast path
1418  // if (minCrd > offset)
1419  // offset += 1
1420  b.setInsertionPointToStart(&ifOp.getThenRegion().front());
1421  Value nxOffset = ADDI(getAbsOff(), c1);
1422  YIELD((ValueRange{getMinCrd(), nxOffset, getNotEnd()}));
1423 
1424  // else /*minCrd == offset*/ {
1425  // for (i = 0; i < tupleCnt; i++) {
1426  // wrap->deserialize(pos[i]);
1427  // minCrd=min(minCrd, *wrap);
1428  // }
1429  // offset = minCrd - size + 1;
1430  // }
1431  b.setInsertionPointToStart(&ifOp.getElseRegion().front());
1432  SmallVector<Value, 2> loopArgs{C_IDX(-1), // nextMinCrd
1433  C_FALSE}; // isNotEnd
1434  auto loopNest = scf::buildLoopNest(
1435  b, l, c0, tupleCnt, c1, loopArgs,
1436  [this](OpBuilder &b, Location l, ValueRange ivs,
1437  ValueRange iterArgs) -> scf::ValueVector {
1438  Value tupleId = ivs.front();
1439  SubSectIterHelper helper(*this);
1440  helper.deserializeFromTupleId(b, l, tupleId);
1441 
1442  return genWhenInBound(
1443  b, l, *delegate, /*elseRet=*/iterArgs,
1444  [this, iterArgs, tupleId](OpBuilder &b, Location l,
1445  Value crd) -> scf::ValueVector {
1446  // if coord == minCrd
1447  // wrap->forward();
1448  Value isMin = CMPI(eq, crd, getMinCrd());
1449  delegate->forwardIf(b, l, isMin);
1450  // Update the forwarded iterator values if needed.
1451  auto ifIsMin = b.create<scf::IfOp>(l, isMin, false);
1452  b.setInsertionPointToStart(&ifIsMin.getThenRegion().front());
1453  storeCursorVals(b, l, tupleId, delegate->serialize());
1454  b.setInsertionPointAfter(ifIsMin);
1455  // if (!wrap.end())
1456  // yield(min(nxMinCrd, *wrap), true)
1457  Value nxMin = iterArgs[0];
1458  return genWhenInBound(b, l, *delegate, /*elseRet=*/iterArgs,
1459  [nxMin](OpBuilder &b, Location l,
1460  Value crd) -> scf::ValueVector {
1461  Value nx = b.create<arith::MinUIOp>(
1462  l, crd, nxMin);
1463  return {nx, C_TRUE};
1464  });
1465  });
1466  });
1467 
1468  scf::ForOp forOp = loopNest.loops.front();
1469  b.setInsertionPointAfter(forOp);
1470 
1471  Value nxMinCrd = forOp.getResult(0);
1472  Value nxNotEnd = forOp.getResult(1);
1473  Value nxAbsOff = offsetFromMinCrd(b, l, nxMinCrd, subSectSz);
1474  YIELD((ValueRange{nxMinCrd, nxAbsOff, nxNotEnd}));
1475  }
1476 
1477  Value nxMinCrd = ifOp.getResult(0);
1478  Value nxAbsOff = ifOp.getResult(1);
1479  Value nxNotEnd = ifOp.getResult(2);
1480 
1481  // We should at least forward the offset by one.
1482  Value minAbsOff = ADDI(getAbsOff(), c1);
1483  nxAbsOff = b.create<arith::MaxUIOp>(l, minAbsOff, nxAbsOff);
1484 
1485  seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
 1486  // The coordinate should not exceed the space's upper bound.
1487  Value crd = deref(b, l);
1488  nxNotEnd = ANDI(nxNotEnd, CMPI(ult, crd, upperBound(b, l)));
1489 
1490  seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
1491  return getCursor();
1492 }
1493 
1494 //===----------------------------------------------------------------------===//
1495 // SparseIterationSpace Implementation
1496 //===----------------------------------------------------------------------===//
1497 
 1498 mlir::sparse_tensor::SparseIterationSpace::SparseIterationSpace(
 1499     Location l, OpBuilder &b, Value t, unsigned tid,
1500  std::pair<Level, Level> lvlRange, ValueRange parentPos)
1501  : lvls() {
1502  auto [lvlLo, lvlHi] = lvlRange;
1503 
1504  Value c0 = C_IDX(0);
1505  if (parentPos.empty())
1506  parentPos = c0;
1507 
1508  for (Level lvl = lvlLo; lvl < lvlHi; lvl++)
1509  lvls.emplace_back(makeSparseTensorLevel(b, l, t, tid, lvl));
1510 
1511  bound = lvls.front()->peekRangeAt(b, l, /*batchPrefix=*/{}, parentPos);
1512  for (auto &lvl : getLvlRef().drop_front())
1513  bound = lvl->collapseRangeBetween(b, l, /*batchPrefix=*/{}, bound);
1514 }
1515 
 1516 SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
 1517     IterSpaceType dstTp, ValueRange values, unsigned int tid) {
1518  // Reconstruct every sparse tensor level.
1519  SparseIterationSpace space;
1520  for (auto [i, lt] : llvm::enumerate(dstTp.getLvlTypes())) {
1521  unsigned bufferCnt = 0;
1522  if (lt.isWithPosLT())
1523  bufferCnt++;
1524  if (lt.isWithCrdLT())
1525  bufferCnt++;
1526  // Sparse tensor buffers.
1527  ValueRange buffers = values.take_front(bufferCnt);
1528  values = values.drop_front(bufferCnt);
1529 
1530  // Level size.
1531  Value sz = values.front();
1532  values = values.drop_front();
1533  space.lvls.push_back(
1534  makeSparseTensorLevel(lt, sz, buffers, tid, i + dstTp.getLoLvl()));
1535  }
1536  // Two bounds.
1537  space.bound = std::make_pair(values[0], values[1]);
1538  values = values.drop_front(2);
1539 
1540  // Must have consumed all values.
1541  assert(values.empty());
1542  return space;
1543 }
1544 
1545 std::unique_ptr<SparseIterator>
 1546 SparseIterationSpace::extractIterator(OpBuilder &b, Location l) const {
 1547   return makeSimpleIterator(b, l, *this);
1548 }
1549 
1550 //===----------------------------------------------------------------------===//
1551 // SparseIterator factory functions.
1552 //===----------------------------------------------------------------------===//
1553 
 1554 /// Helper function to create a TensorLevel object from the given `tensor`.
1555 std::unique_ptr<SparseTensorLevel>
 1556 sparse_tensor::makeSparseTensorLevel(LevelType lt, Value sz, ValueRange b,
 1557                                      unsigned t, Level l) {
1558  assert(lt.getNumBuffer() == b.size());
1559  switch (lt.getLvlFmt()) {
1560  case LevelFormat::Dense:
1561  return std::make_unique<DenseLevel>(t, l, sz);
1562  case LevelFormat::Batch:
 1563     return std::make_unique<BatchLevel>(t, l, sz);
 1564   case LevelFormat::Compressed:
 1565     return std::make_unique<CompressedLevel>(t, l, lt, sz, b[0], b[1]);
 1566   case LevelFormat::LooseCompressed:
 1567     return std::make_unique<LooseCompressedLevel>(t, l, lt, sz, b[0], b[1]);
 1568   case LevelFormat::Singleton:
 1569     return std::make_unique<SingletonLevel>(t, l, lt, sz, b[0]);
1570  case LevelFormat::NOutOfM:
1571  return std::make_unique<NOutOfMLevel>(t, l, lt, sz, b[0]);
1572  case LevelFormat::Undef:
1573  llvm_unreachable("undefined level format");
1574  }
1575  llvm_unreachable("unrecognizable level format");
1576 }
1577 
1578 std::unique_ptr<SparseTensorLevel>
 1579 sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
 1580                                      unsigned tid, Level lvl) {
1581  auto stt = getSparseTensorType(t);
1582 
1583  LevelType lt = stt.getLvlType(lvl);
1584  Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
1585  : b.create<tensor::DimOp>(l, t, lvl).getResult();
1586 
1587  SmallVector<Value, 2> buffers;
1588  if (lt.isWithPosLT()) {
1589  Value pos = b.create<ToPositionsOp>(l, t, lvl);
1590  buffers.push_back(pos);
1591  }
1592  if (lt.isWithCrdLT()) {
1593  Value pos = b.create<ToCoordinatesOp>(l, t, lvl);
1594  buffers.push_back(pos);
1595  }
1596  return makeSparseTensorLevel(lt, sz, buffers, tid, lvl);
1597 }
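// Illustrative caller sketch (hypothetical names, not verbatim from this
// file), building one SparseTensorLevel per level of a sparse tensor value:
//
//   auto stt = getSparseTensorType(tensor);
//   SmallVector<std::unique_ptr<SparseTensorLevel>> lvls;
//   for (Level lvl = 0; lvl < stt.getLvlRank(); ++lvl)
//     lvls.push_back(makeSparseTensorLevel(builder, loc, tensor, /*tid=*/0, lvl));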
1598 
1599 std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
1600 sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
1601  SparseEmitStrategy strategy) {
1602  auto stl = std::make_unique<BatchLevel>(tid, lvl, sz);
1603  auto it = std::make_unique<TrivialIterator>(*stl);
1604  it->setSparseEmitStrategy(strategy);
1605  return std::make_pair(std::move(stl), std::move(it));
1606 }
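// Note (not verbatim from this file): a "synthetic" level is modeled as a
// BatchLevel of the given size and traversed with a TrivialIterator; both are
// returned so the caller keeps the level alive for the iterator's lifetime.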
1607 
1608 std::unique_ptr<SparseIterator>
1609 sparse_tensor::makeSimpleIterator(OpBuilder &b, Location l,
1610  const SparseIterationSpace &iterSpace) {
1611  // assert(iterSpace.getSpaceDim() == 1);
1612  std::unique_ptr<SparseIterator> ret;
1613  if (!iterSpace.isUnique()) {
1614  // We always deduplicate the non-unique level, but we should optimize it away
1615  // if possible.
1616  ret = std::make_unique<DedupIterator>(b, l, iterSpace.getLastLvl(),
1617  iterSpace.getBoundLo(),
1618  iterSpace.getBoundHi());
1619  } else {
1620  ret = std::make_unique<TrivialIterator>(b, l, iterSpace.getLastLvl(),
1621  iterSpace.getBoundLo(),
1622  iterSpace.getBoundHi());
1623  }
1624  ret->setSparseEmitStrategy(SparseEmitStrategy::kFunctional);
1625  return ret;
1626 }
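// Note (not verbatim from this file): this overload always applies
// SparseEmitStrategy::kFunctional, while the level-based overload below
// forwards the caller-provided strategy; both pick a DedupIterator for
// non-unique levels and a TrivialIterator otherwise.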
1627 
1628 std::unique_ptr<SparseIterator>
1629 sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
1630  SparseEmitStrategy strategy) {
1631  std::unique_ptr<SparseIterator> ret;
1632  if (!isUniqueLT(stl.getLT())) {
1633  // We always deduplicate the non-unique level, but we should optimize it away
1634  // if possible.
1635  ret = std::make_unique<DedupIterator>(stl);
1636  } else {
1637  ret = std::make_unique<TrivialIterator>(stl);
1638  }
1639  ret->setSparseEmitStrategy(strategy);
1640  return ret;
1641 }
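// Illustrative usage sketch (hypothetical names, not verbatim from this
// file), combining the level-based overload with the slice wrapper defined
// below; `lvl0` is a SparseTensorLevel and `offset`, `stride`, `size` are
// index-typed SSA values:
//
//   std::unique_ptr<SparseIterator> it =
//       makeSimpleIterator(lvl0, SparseEmitStrategy::kFunctional);
//   it = makeSlicedLevelIterator(std::move(it), offset, stride, size,
//                                SparseEmitStrategy::kFunctional);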
1642 
1643 std::unique_ptr<SparseIterator>
1644 sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
1645  Value offset, Value stride, Value size,
1646  SparseEmitStrategy strategy) {
1647 
1648  auto ret =
1649  std::make_unique<FilterIterator>(std::move(sit), offset, stride, size);
1650  ret->setSparseEmitStrategy(strategy);
1651  return ret;
1652 }
1653 
1654 std::unique_ptr<SparseIterator>
1655 sparse_tensor::makePaddedIterator(std::unique_ptr<SparseIterator> &&sit,
1656  Value padLow, Value padHigh,
1657  SparseEmitStrategy strategy) {
1658  auto ret = std::make_unique<PadIterator>(std::move(sit), padLow, padHigh);
1659  ret->setSparseEmitStrategy(strategy);
1660  return ret;
1661 }
1662 
1663 static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) {
1664  auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
1665  if (filter)
1666  return &filter->getWrappedIterator();
1667  return it;
1668 }
1669 
1670 std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
1671  OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
1672  std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
1673  SparseEmitStrategy strategy) {
1674 
1675  // Try to unwrap the NonEmptySubSectIterator from a filter parent.
1676  parent = tryUnwrapFilter(parent);
1677  std::unique_ptr<SparseIterator> it =
1678  std::make_unique<NonEmptySubSectIterator>(b, l, parent,
1679  std::move(delegate), size);
1680 
1681  if (stride != 1) {
1682  // TODO: We can safely skip bound checking on sparse levels, but for dense
1683  // iteration space, we need the bound to infer the dense loop range.
1684  it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1685  C_IDX(stride), /*size=*/loopBound);
1686  }
1687  it->setSparseEmitStrategy(strategy);
1688  return it;
1689 }
1690 
1691 std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator(
1692  OpBuilder &b, Location l, const SparseIterator &subSectIter,
1693  const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
1694  Value loopBound, unsigned stride, SparseEmitStrategy strategy) {
1695 
1696  // This must be a subsection iterator or a filtered subsection iterator.
1697  auto &subSect =
1698  llvm::cast<NonEmptySubSectIterator>(*tryUnwrapFilter(&subSectIter));
1699 
1700  std::unique_ptr<SparseIterator> it = std::make_unique<SubSectIterator>(
1701  subSect, *tryUnwrapFilter(&parent), std::move(wrap));
1702 
1703  if (stride != 1) {
1704  it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1705  C_IDX(stride), /*size=*/loopBound);
1706  }
1707  it->setSparseEmitStrategy(strategy);
1708  return it;
1709 }
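// Note (not verbatim from this file): both subsection factories above handle
// a non-unit stride the same way, by wrapping the iterator in a
// FilterIterator with offset 0, step `stride`, and the loop bound as size, so
// strided accesses are filtered rather than handled by the iterators
// themselves.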
1710 
1711 #undef CMPI
1712 #undef C_IDX
1713 #undef YIELD
1714 #undef ADDI
1715 #undef ANDI
1716 #undef SUBI
1717 #undef MULI
1718 #undef SELECT