1 //===- SparseTensorIterator.cpp -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "SparseTensorIterator.h"
10 #include "CodegenUtils.h"
11 
15 
16 using namespace mlir;
17 using namespace mlir::sparse_tensor;
18 using ValuePair = std::pair<Value, Value>;
19 using ValueTuple = std::tuple<Value, Value, Value>;
20 
21 //===----------------------------------------------------------------------===//
22 // File local helper functions/macros.
23 //===----------------------------------------------------------------------===//
24 #define CMPI(p, lhs, rhs) \
25  (arith::CmpIOp::create(b, l, arith::CmpIPredicate::p, (lhs), (rhs)) \
26  .getResult())
27 
28 #define C_FALSE (constantI1(b, l, false))
29 #define C_TRUE (constantI1(b, l, true))
30 #define C_IDX(v) (constantIndex(b, l, (v)))
31 #define YIELD(vs) (scf::YieldOp::create(b, l, (vs)))
32 #define ADDI(lhs, rhs) (arith::AddIOp::create(b, l, (lhs), (rhs)).getResult())
33 #define ORI(lhs, rhs) (arith::OrIOp::create(b, l, (lhs), (rhs)).getResult())
34 #define ANDI(lhs, rhs) (arith::AndIOp::create(b, l, (lhs), (rhs)).getResult())
35 #define SUBI(lhs, rhs) (arith::SubIOp::create(b, l, (lhs), (rhs)).getResult())
36 #define MULI(lhs, rhs) (arith::MulIOp::create(b, l, (lhs), (rhs)).getResult())
37 #define MINUI(lhs, rhs) (arith::MinUIOp::create(b, l, (lhs), (rhs)).getResult())
38 #define REMUI(lhs, rhs) (arith::RemUIOp::create(b, l, (lhs), (rhs)).getResult())
39 #define DIVUI(lhs, rhs) (arith::DivUIOp::create(b, l, (lhs), (rhs)).getResult())
40 #define SELECT(c, lhs, rhs) \
41  (arith::SelectOp::create(b, l, (c), (lhs), (rhs)).getResult())
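// For illustration only: these helper macros assume an OpBuilder `b` and a
// Location `l` in scope. For example, the expression
//   SELECT(CMPI(uge, minCrd, size), SUBI(ADDI(minCrd, C_IDX(1)), size), C_IDX(0))
// (as used by offsetFromMinCrd below) builds roughly the following ops at `l`:
//   %cmp = arith.cmpi uge, %minCrd, %size : index
//   %add = arith.addi %minCrd, %c1 : index
//   %sub = arith.subi %add, %size : index
//   %sel = arith.select %cmp, %sub, %c0 : index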
42 
43 //===----------------------------------------------------------------------===//
44 // SparseTensorLevel derived classes.
45 //===----------------------------------------------------------------------===//
46 
47 namespace {
48 
49 template <bool hasPosBuffer>
50 class SparseLevel : public SparseTensorLevel {
51  // It is either an array of size 2 or size 1 depending on whether the sparse
52  // level requires a position array.
53  using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
54  std::array<Value, 1>>;
55 
56 public:
57  SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
58  BufferT buffers)
59  : SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}
60 
61  ValueRange getLvlBuffers() const override { return buffers; }
62 
63  Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
64  Value iv) const override {
65  SmallVector<Value> memCrd(batchPrefix);
66  memCrd.push_back(iv);
67  return genIndexLoad(b, l, getCrdBuf(), memCrd);
68  }
69 
70 protected:
71  template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
72  Value getPosBuf() const {
73  return buffers[0];
74  }
75 
76  Value getCrdBuf() const {
77  if constexpr (hasPosBuffer)
78  return buffers[1];
79  else
80  return buffers[0];
81  }
82 
83  const BufferT buffers;
84 };
85 
86 class DenseLevel : public SparseTensorLevel {
87 public:
88  DenseLevel(unsigned tid, Level lvl, Value lvlSize)
89  : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize) {}
90 
91  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
92  llvm_unreachable("locate random-accessible level instead");
93  }
94 
95  ValueRange getLvlBuffers() const override { return {}; }
96 
97  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
98  ValueRange parentPos, Value inPadZone) const override {
99  assert(parentPos.size() == 1 && "Dense level cannot be non-unique.");
100  assert(!inPadZone && "Not implemented");
101  Value p = parentPos.front();
102  Value posLo = MULI(p, lvlSize);
103  return {posLo, lvlSize};
104  }
105 };
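// For illustration (values are hypothetical): for a dense level of size 10, a
// parent position p = 3 linearizes to posLo = 3 * 10 = 30, i.e. the children
// of that parent entry start at position 30.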
106 
107 class BatchLevel : public SparseTensorLevel {
108 public:
109  BatchLevel(unsigned tid, Level lvl, Value lvlSize)
110  : SparseTensorLevel(tid, lvl, LevelFormat::Batch, lvlSize) {}
111 
112  Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
113  llvm_unreachable("locate random-accessible level instead");
114  }
115 
116  ValueRange getLvlBuffers() const override { return {}; }
117 
118  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
119  ValueRange parentPos, Value inPadZone) const override {
120  assert(!inPadZone && "Not implemented");
121  assert(parentPos.size() == 1 && "Batch level cannot be non-unique.");
122  // No need to linearize the position for non-annotated tensors.
123  return {C_IDX(0), lvlSize};
124  }
125 };
126 
127 class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
128 public:
129  CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
130  Value posBuffer, Value crdBuffer)
131  : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
132 
133  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
134  ValueRange parentPos, Value inPadZone) const override {
135 
136  assert(parentPos.size() == 1 &&
137  "compressed level must be the first non-unique level.");
138 
139  auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
140  Value p = parentPos.front();
141  SmallVector<Value> memCrd(batchPrefix);
142  memCrd.push_back(p);
143  Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
144  memCrd.back() = ADDI(p, C_IDX(1));
145  Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
146  return {pLo, pHi};
147  };
148 
149  if (inPadZone == nullptr)
150  return loadRange();
151 
152  SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
153  scf::IfOp posRangeIf = scf::IfOp::create(b, l, types, inPadZone, true);
154  // True branch, returns a "fake" empty range [0, 0) if parent
155  // iterator is in pad zone.
156  b.setInsertionPointToStart(posRangeIf.thenBlock());
157 
158  SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
159  scf::YieldOp::create(b, l, emptyRange);
160 
161  // False branch, returns the actual range.
162  b.setInsertionPointToStart(posRangeIf.elseBlock());
163  auto [pLo, pHi] = loadRange();
164  SmallVector<Value, 2> loadedRange{pLo, pHi};
165  scf::YieldOp::create(b, l, loadedRange);
166 
167  b.setInsertionPointAfter(posRangeIf);
168  ValueRange posRange = posRangeIf.getResults();
169  return {posRange.front(), posRange.back()};
170  }
171 };
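// For illustration (values are hypothetical): with a CSR-like positions array
// positions = [0, 2, 5, 7], a parent position p = 1 loads
//   pLo = positions[1] = 2 and pHi = positions[2] = 5,
// i.e. the coordinates in [2, 5) belong to that parent entry. When `inPadZone`
// is provided, the scf.if above yields the empty range [0, 0) instead.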
172 
173 class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
174 public:
175  LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
176  Value posBuffer, Value crdBuffer)
177  : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
178 
179  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
180  ValueRange parentPos, Value inPadZone) const override {
181  assert(parentPos.size() == 1 &&
182  "loose-compressed level must be the first non-unique level.");
183  assert(!inPadZone && "Not implemented");
184  SmallVector<Value> memCrd(batchPrefix);
185  Value p = parentPos.front();
186  p = MULI(p, C_IDX(2));
187  memCrd.push_back(p);
188  Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
189  memCrd.back() = ADDI(p, C_IDX(1));
190  Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
191  return {pLo, pHi};
192  }
193 };
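// For illustration (values are hypothetical): a loose-compressed level stores
// a (lo, hi) pair per parent entry, so parent position p = 1 reads
//   pLo = positions[2] and pHi = positions[3].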
194 
195 class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
196 public:
197  SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
198  Value crdBuffer)
199  : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
200 
201  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
202  ValueRange parentPos, Value inPadZone) const override {
203  assert(parentPos.size() == 1 || parentPos.size() == 2);
204  assert(!inPadZone && "Not implemented");
205  Value p = parentPos.front();
206  Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
207 
208  if (segHi == nullptr)
209  return {p, ADDI(p, C_IDX(1))};
210  // Use the segHi as the loop upper bound.
211  return {p, segHi};
212  }
213 
214  ValuePair
215  collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
216  std::pair<Value, Value> parentRange) const override {
217  // Singleton level keeps the same range after collapsing.
218  return parentRange;
219  };
220 };
221 
222 class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
223 public:
224  NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
225  Value crdBuffer)
226  : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
227 
228  ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
229  ValueRange parentPos, Value inPadZone) const override {
230  assert(parentPos.size() == 1 && isUnique() &&
231  "n:m level cannot be non-unique.");
232  assert(!inPadZone && "Not implemented");
233  // Each n:m blk has exactly n specified elements.
234  auto n = getN(lt);
235  Value posLo = MULI(parentPos.front(), C_IDX(n));
236  return {posLo, ADDI(posLo, C_IDX(n))};
237  }
238 };
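// For illustration: in a 2:4 level (n = 2), parent position 3 yields
//   posLo = 3 * 2 = 6 and the range [6, 8),
// since every n:m block stores exactly n specified elements.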
239 
240 } // namespace
241 
242 //===----------------------------------------------------------------------===//
243 // File local helpers
244 //===----------------------------------------------------------------------===//
245 
246 static scf::ValueVector genWhenInBound(
247     OpBuilder &b, Location l, SparseIterator &it, ValueRange elseRet,
248     llvm::function_ref<scf::ValueVector(OpBuilder &b, Location l, Value crd)>
249         builder) {
250  TypeRange ifRetTypes = elseRet.getTypes();
251  auto ifOp = scf::IfOp::create(b, l, ifRetTypes, it.genNotEnd(b, l), true);
252 
253  b.setInsertionPointToStart(ifOp.thenBlock());
254  Value crd = it.deref(b, l);
255  scf::ValueVector ret = builder(b, l, crd);
256  YIELD(ret);
257 
258  b.setInsertionPointToStart(ifOp.elseBlock());
259  YIELD(elseRet);
260 
261  b.setInsertionPointAfter(ifOp);
262  return ifOp.getResults();
263 }
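// For reference, genWhenInBound emits roughly the following structure (a
// sketch; <...> stands for caller-provided pieces):
//   %r:N = scf.if %it.not_end -> (<elseRet types>) {
//     %crd = <it.deref()>
//     scf.yield <builder(b, l, %crd)>
//   } else {
//     scf.yield <elseRet>
//   }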
264 
265 /// Generates code to compute the *absolute* offset of the slice based on the
266 /// provided minimum coordinate in the slice.
267 /// E.g., when reducing d0 + d1 + d2, we need two slices to fully reduce the
268 /// expression, i.e., s1 = slice(T, d0), s2 = slice(s1, d1). The *absolute*
269 /// offset is the offset computed relative to the initial tensor T.
270 ///
271 /// When isNonEmpty == true, the computed offset is meaningless and should not
272 /// be used at runtime; currently the method generates code that returns 0 in
273 /// that case.
274 ///
275 /// offset = minCrd >= size ? minCrd - size + 1 : 0;
276 static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
277                               Value size) {
278  Value geSize = CMPI(uge, minCrd, size);
279  // Compute minCrd - size + 1.
280  Value mms = SUBI(ADDI(minCrd, C_IDX(1)), size);
281  // This is the absolute offset related to the actual tensor.
282  return SELECT(geSize, mms, C_IDX(0));
283 }
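// For illustration: with size = 4, minCrd = 9 gives offset = 9 - 4 + 1 = 6,
// while minCrd = 2 (< size) gives offset = 0.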
284 
285 //===----------------------------------------------------------------------===//
286 // SparseIterator derived classes.
287 //===----------------------------------------------------------------------===//
288 
289 namespace {
290 
291 // The iterator that traverses a concrete sparse tensor level. High-level
292 // abstract iterators wrap it to achieve more complex goals (such as collapsing
293 // several levels). It also owns the common storage that holds the mlir::Values
294 // for itself as well as for its wrappers.
295 class ConcreteIterator : public SparseIterator {
296 protected:
297  ConcreteIterator(const SparseTensorLevel &stl, IterKind kind,
298  unsigned cursorValCnt)
299  : SparseIterator(kind, stl.tid, stl.lvl, cursorValCnt, cursorValsStorage),
300  stl(stl), cursorValsStorage(cursorValCnt, nullptr) {
301  assert(getCursor().size() == cursorValCnt);
302  };
303 
304 public:
305  // For LLVM-style RTTI.
306  static bool classof(const SparseIterator *from) {
307  return from->kind == IterKind::kTrivial;
308  }
309 
310  bool isBatchIterator() const override {
311  return stl.getLT().isa<LevelFormat::Batch>();
312  }
313  bool randomAccessible() const override {
314  return stl.getLT().hasDenseSemantic();
315  };
316  bool iteratableByFor() const override { return kind != IterKind::kDedup; };
317  Value upperBound(OpBuilder &b, Location l) const override {
318  return stl.getSize();
319  };
320 
321 protected:
322  const SparseTensorLevel &stl;
323  // Owner of the storage; all wrappers built on top of a concrete iterator
324  // share the same storage so that the iterator values are always
325  // synchronized.
326  SmallVector<Value> cursorValsStorage;
327 };
328 
329 class TrivialIterator : public ConcreteIterator {
330 public:
331  TrivialIterator(const SparseTensorLevel &stl)
332  : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
333 
334  TrivialIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
335  Value posLo, Value posHi)
336  : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1), posLo(posLo),
337  posHi(posHi) {
338  seek(posLo);
339  }
340 
341  std::string getDebugInterfacePrefix() const override {
342  return std::string("trivial<") + stl.toString() + ">";
343  }
344  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
345  return {b.getIndexType()};
346  }
347 
348  SmallVector<Value> serialize() const override {
349  SmallVector<Value> ret;
350  ret.push_back(getItPos());
351  if (randomAccessible()) {
352  // The loop upper bound is implicit (defined by `upperBound()`) for a
353  // random-access iterator, but we need to remember posLo for linearization.
354  ret.push_back(posLo);
355  } else {
356  ret.push_back(posHi);
357  }
358  return ret;
359  };
360 
361  void deserialize(ValueRange vs) override {
362  assert(vs.size() == 2);
363  seek(vs.front());
364  if (randomAccessible())
365  posLo = vs.back();
366  else
367  posHi = vs.back();
368  };
369 
370  void genInitImpl(OpBuilder &b, Location l,
371  const SparseIterator *parent) override;
372 
373  ValuePair genForCond(OpBuilder &b, Location l) override {
374  if (randomAccessible())
375  return {deref(b, l), upperBound(b, l)};
376  return std::make_pair(getItPos(), posHi);
377  }
378 
379  Value genNotEndImpl(OpBuilder &b, Location l) override {
380  // We use the first level's bound as the bound for the collapsed set of levels.
381  return CMPI(ult, getItPos(), posHi);
382  }
383 
384  Value derefImpl(OpBuilder &b, Location l) override {
385  if (randomAccessible()) {
386  updateCrd(SUBI(getItPos(), posLo));
387  } else {
388  updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getItPos()));
389  }
390  return getCrd();
391  };
392 
393  ValueRange forwardImpl(OpBuilder &b, Location l) override {
394  seek(ADDI(getItPos(), C_IDX(1)));
395  return getCursor();
396  }
397 
398  ValueRange forwardIf(OpBuilder &b, Location l, Value cond) override {
399  Value curPos = getCursor().front();
400  Value nxPos = forward(b, l).front();
401  seek(SELECT(cond, nxPos, curPos));
402  return getCursor();
403  }
404 
405  void locateImpl(OpBuilder &b, Location l, Value crd) override {
406  assert(randomAccessible());
407  // Seek to the linearized position.
408  seek(ADDI(crd, posLo));
409  updateCrd(crd);
410  if (isBatchIterator()) {
411  // If this is a batch iterator, also update the batch coordinate.
412  assert(batchCrds.size() > lvl);
413  batchCrds[lvl] = crd;
414  }
415  }
416 
417  Value getItPos() const { return getCursor().front(); }
418  Value posLo, posHi;
419 };
420 
421 class DedupIterator : public ConcreteIterator {
422 private:
423  Value genSegmentHigh(OpBuilder &b, Location l, Value pos);
424 
425 public:
426  DedupIterator(const SparseTensorLevel &stl)
427  : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) {
428  assert(!stl.isUnique());
429  }
430 
431  DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
432  Value posLo, Value posHi)
433  : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2), posHi(posHi) {
434  assert(!stl.isUnique());
435  seek({posLo, genSegmentHigh(b, l, posLo)});
436  }
437 
438  // For LLVM-style RTTI.
439  static bool classof(const SparseIterator *from) {
440  return from->kind == IterKind::kDedup;
441  }
442 
443  std::string getDebugInterfacePrefix() const override {
444  return std::string("dedup<") + stl.toString() + ">";
445  }
446  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
447  return {b.getIndexType(), b.getIndexType()};
448  }
449 
450  void genInitImpl(OpBuilder &b, Location l,
451  const SparseIterator *parent) override {
452  Value c0 = C_IDX(0);
453  ValueRange pPos = c0;
454 
455  // If the parent iterator is a batch iterator, we also start from 0 (but
456  // on a different batch).
457  if (parent && !parent->isBatchIterator())
458  pPos = parent->getCurPosition();
459 
460  Value posLo;
461  ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
462  std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
463 
464  seek({posLo, genSegmentHigh(b, l, posLo)});
465  }
466 
467  SmallVector<Value> serialize() const override {
468  SmallVector<Value> ret;
469  ret.append(getCursor().begin(), getCursor().end());
470  ret.push_back(posHi);
471  return ret;
472  };
473  void deserialize(ValueRange vs) override {
474  assert(vs.size() == 3);
475  seek(vs.take_front(getCursor().size()));
476  posHi = vs.back();
477  };
478 
479  Value genNotEndImpl(OpBuilder &b, Location l) override {
480  return CMPI(ult, getPos(), posHi);
481  }
482 
483  Value derefImpl(OpBuilder &b, Location l) override {
484  updateCrd(stl.peekCrdAt(b, l, getBatchCrds(), getPos()));
485  return getCrd();
486  };
487 
488  ValueRange forwardImpl(OpBuilder &b, Location l) override {
489  Value nxPos = getSegHi(); // forward the position to the next segment.
490  seek({nxPos, genSegmentHigh(b, l, nxPos)});
491  return getCursor();
492  }
493 
494  Value getPos() const { return getCursor()[0]; }
495  Value getSegHi() const { return getCursor()[1]; }
496 
497  Value posHi;
498 };
499 
500 // A utility base iterator that delegates all methods to the wrapped iterator.
501 class SimpleWrapIterator : public SparseIterator {
502 public:
503  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
504  unsigned extraCursorVal = 0)
505  : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}
506 
507  void setSparseEmitStrategy(SparseEmitStrategy strategy) override {
508  wrap->setSparseEmitStrategy(strategy);
509  }
510 
511  SparseEmitStrategy getSparseEmitStrategy() const override {
512  return wrap->getSparseEmitStrategy();
513  }
514 
515  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
516  return wrap->getCursorValTypes(b);
517  }
518  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
519  bool randomAccessible() const override { return wrap->randomAccessible(); };
520  bool iteratableByFor() const override { return wrap->iteratableByFor(); };
521 
522  SmallVector<Value> serialize() const override { return wrap->serialize(); };
523  void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
524  ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
525  void genInitImpl(OpBuilder &b, Location l,
526  const SparseIterator *parent) override {
527  wrap->genInit(b, l, parent);
528  }
529  Value genNotEndImpl(OpBuilder &b, Location l) override {
530  return wrap->genNotEndImpl(b, l);
531  }
532  ValueRange forwardImpl(OpBuilder &b, Location l) override {
533  return wrap->forward(b, l);
534  };
535  Value upperBound(OpBuilder &b, Location l) const override {
536  return wrap->upperBound(b, l);
537  };
538 
539  Value derefImpl(OpBuilder &b, Location l) override {
540  return wrap->derefImpl(b, l);
541  }
542 
543  void locateImpl(OpBuilder &b, Location l, Value crd) override {
544  return wrap->locate(b, l, crd);
545  }
546 
547  SparseIterator &getWrappedIterator() const { return *wrap; }
548 
549 protected:
550  std::unique_ptr<SparseIterator> wrap;
551 };
552 
553 //
554 // A filter iterator wrapped around another iterator. The filter iterator
555 // updates the wrapped iterator *in-place*.
556 //
557 class FilterIterator : public SimpleWrapIterator {
558  // Coordinate translation between the crd loaded from the wrapped iterator
559  // and the filter iterator's crd.
560  Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) const {
561  // crd = (wrapCrd - offset) / stride
562  return DIVUI(SUBI(wrapCrd, offset), stride);
563  }
564  Value toWrapCrd(OpBuilder &b, Location l, Value crd) const {
565  // wrapCrd = crd * stride + offset
566  return ADDI(MULI(crd, stride), offset);
567  }
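// For illustration (values are hypothetical): with offset = 2 and stride = 3,
// wrapped coordinates 2, 5, 8, ... map to filter coordinates 0, 1, 2, ...;
// e.g. fromWrapCrd(8) = (8 - 2) / 3 = 2 and toWrapCrd(2) = 2 * 3 + 2 = 8.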
568 
569  Value genCrdNotLegitPredicate(OpBuilder &b, Location l, Value wrapCrd);
570 
571  Value genShouldFilter(OpBuilder &b, Location l);
572 
573 public:
574  // TODO: avoid unnecessary checks when offset == 0 and/or when stride == 1
575  // and/or when crd is always < size.
576  FilterIterator(std::unique_ptr<SparseIterator> &&wrap, Value offset,
577  Value stride, Value size)
578  : SimpleWrapIterator(std::move(wrap), IterKind::kFilter), offset(offset),
579  stride(stride), size(size) {}
580 
581  // For LLVM-style RTTI.
582  static bool classof(const SparseIterator *from) {
583  return from->kind == IterKind::kFilter;
584  }
585 
586  std::string getDebugInterfacePrefix() const override {
587  return std::string("filter<") + wrap->getDebugInterfacePrefix() + ">";
588  }
589 
590  bool iteratableByFor() const override { return randomAccessible(); };
591  Value upperBound(OpBuilder &b, Location l) const override { return size; };
592 
593  void genInitImpl(OpBuilder &b, Location l,
594  const SparseIterator *parent) override {
595  wrap->genInit(b, l, parent);
596  if (!randomAccessible()) {
597  // TODO: we can skip this when stride == 1 and offset == 0, we can also
598  // use binary search here.
599  forwardIf(b, l, genShouldFilter(b, l));
600  } else {
601  // Else, locate to the slice.offset, which is the first coordinate
602  // included by the slice.
603  wrap->locate(b, l, offset);
604  }
605  }
606 
607  Value genNotEndImpl(OpBuilder &b, Location l) override;
608 
609  Value derefImpl(OpBuilder &b, Location l) override {
610  updateCrd(fromWrapCrd(b, l, wrap->deref(b, l)));
611  return getCrd();
612  }
613 
614  void locateImpl(OpBuilder &b, Location l, Value crd) override {
615  assert(randomAccessible());
616  wrap->locate(b, l, toWrapCrd(b, l, crd));
617  updateCrd(crd);
618  }
619 
620  ValueRange forwardImpl(OpBuilder &b, Location l) override;
621 
622  Value offset, stride, size;
623 };
624 
625 //
626 // A pad iterator wrapped from another iterator. The pad iterator updates
627 // the wrapped iterator *in-place*.
628 //
629 class PadIterator : public SimpleWrapIterator {
630 
631 public:
632  PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
633  Value padHigh)
634  : SimpleWrapIterator(std::move(wrap), IterKind::kPad,
635  wrap->randomAccessible() ? 1 : 0),
636  padLow(padLow), padHigh(padHigh) {}
637 
638  // For LLVM-style RTTI.
639  static bool classof(const SparseIterator *from) {
640  return from->kind == IterKind::kPad;
641  }
642 
643  std::string getDebugInterfacePrefix() const override {
644  return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
645  }
646 
647  // Returns a pair of values for the *lower* and *upper* bound, respectively.
648  ValuePair genForCond(OpBuilder &b, Location l) override {
649  if (randomAccessible())
650  return {getCrd(), upperBound(b, l)};
651  return wrap->genForCond(b, l);
652  }
653 
654  // For a padded dense iterator, we append an `inPadZone: bool` value in
655  // addition to the values used by the wrapped iterator.
656  ValueRange getCurPosition() const override { return getCursor(); }
657 
658  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
659  SmallVector<Type> ret = wrap->getCursorValTypes(b);
660  // Need an extra boolean value `inPadZone` for padded dense iterator.
661  if (randomAccessible())
662  ret.push_back(b.getI1Type());
663 
664  return ret;
665  }
666 
667  // The upper bound after padding becomes `size + padLow + padHigh`.
668  Value upperBound(OpBuilder &b, Location l) const override {
669  return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
670  };
671 
672  // The pad_coord = coord + pad_lo
673  Value derefImpl(OpBuilder &b, Location l) override {
674  updateCrd(ADDI(wrap->deref(b, l), padLow));
675  return getCrd();
676  }
677 
678  void locateImpl(OpBuilder &b, Location l, Value crd) override {
679  assert(randomAccessible());
680  wrap->locate(b, l, SUBI(crd, padLow));
681 
682  // inPadZone = crd < padLow || crd >= size + padLow.
683  Value inPadLow = CMPI(ult, crd, padLow);
684  Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow));
685  getMutCursorVals().back() = ORI(inPadLow, inPadHigh);
686 
687  updateCrd(crd);
688  }
689 
690  Value padLow, padHigh;
691 };
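// For illustration (values are hypothetical): padding a level of size 5 with
// padLow = 2 and padHigh = 1 gives upperBound = 5 + 2 + 1 = 8; locate(1) sets
// inPadZone = true (1 < padLow), while locate(3) resolves to wrapped
// coordinate 3 - 2 = 1 with inPadZone = false.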
692 
693 class NonEmptySubSectIterator : public SparseIterator {
694 public:
695  using TraverseBuilder = llvm::function_ref<scf::ValueVector(
696      OpBuilder &, Location, const SparseIterator *, ValueRange)>;
697 
698  NonEmptySubSectIterator(OpBuilder &b, Location l,
699  const SparseIterator *parent,
700  std::unique_ptr<SparseIterator> &&delegate,
701  Value subSectSz)
702  : SparseIterator(IterKind::kNonEmptySubSect, 3, subSectMeta, *delegate),
703  parent(parent), delegate(std::move(delegate)),
704  tupleSz(this->delegate->serialize().size()), subSectSz(subSectSz) {
705  auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
706  if (p == nullptr) {
707  // Extract subsections along the root level.
708  maxTupleCnt = C_IDX(1);
709  } else if (p->lvl == lvl) {
710  // Extract subsections along the same level.
711  maxTupleCnt = p->maxTupleCnt;
712  assert(false && "Not implemented.");
713  } else {
714  // Extract subsections along the previous level.
715  assert(p->lvl + 1 == lvl);
716  maxTupleCnt = MULI(p->maxTupleCnt, p->subSectSz);
717  }
718  // We don't need an extra buffer to find subsections on random-accessible
719  // levels.
720  if (randomAccessible())
721  return;
722  subSectPosBuf = allocSubSectPosBuf(b, l);
723  }
724 
725  // For LLVM-style RTTI.
726  static bool classof(const SparseIterator *from) {
727  return from->kind == IterKind::kNonEmptySubSect;
728  }
729 
730  std::string getDebugInterfacePrefix() const override {
731  return std::string("ne_sub<") + delegate->getDebugInterfacePrefix() + ">";
732  }
733  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
734  // minCrd, absolute offset, notEnd
735  return {b.getIndexType(), b.getIndexType(), b.getI1Type()};
736  }
737 
738  // The sliced pointer buffer is organized as:
739  // [[itVal0, itVal1, ..., pNx0],
740  // [itVal0, itVal1, ..., pNx0],
741  // ...]
742  Value allocSubSectPosBuf(OpBuilder &b, Location l) {
743  return memref::AllocaOp::create(
744  b, l,
745  MemRefType::get({ShapedType::kDynamic, tupleSz + 1}, b.getIndexType()),
746  maxTupleCnt);
747  }
748 
749  void storeNxLvlStart(OpBuilder &b, Location l, Value tupleId,
750  Value start) const {
751  memref::StoreOp::create(b, l, start, subSectPosBuf,
752  ValueRange{tupleId, C_IDX(tupleSz)});
753  }
754 
755  Value loadNxLvlStart(OpBuilder &b, Location l, Value tupleId) const {
756  return memref::LoadOp::create(b, l, subSectPosBuf,
757  ValueRange{tupleId, C_IDX(tupleSz)});
758  }
759 
760  void storeCursorVals(OpBuilder &b, Location l, Value tupleId,
761  ValueRange itVals) const {
762  assert(itVals.size() == tupleSz);
763  for (unsigned i = 0; i < tupleSz; i++) {
764  memref::StoreOp::create(b, l, itVals[i], subSectPosBuf,
765  ValueRange{tupleId, C_IDX(i)});
766  }
767  }
768 
769  SmallVector<Value> loadCursorVals(OpBuilder &b, Location l,
770  Value tupleId) const {
771  SmallVector<Value> ret;
772  for (unsigned i = 0; i < tupleSz; i++) {
773  Value v = memref::LoadOp::create(b, l, subSectPosBuf,
774  ValueRange{tupleId, C_IDX(i)});
775  ret.push_back(v);
776  }
777  return ret;
778  }
779 
780  bool isSubSectRoot() const {
781  return !parent || !llvm::isa<NonEmptySubSectIterator>(parent);
782  }
783 
784  // Generates code that inflates the current subsection tree up to the current
785  // level such that every leaf node is visited.
786  ValueRange inflateSubSectTree(OpBuilder &b, Location l, ValueRange reduc,
787  TraverseBuilder builder) const;
788 
789  bool isBatchIterator() const override { return delegate->isBatchIterator(); }
790  bool randomAccessible() const override {
791  return delegate->randomAccessible();
792  };
793  bool iteratableByFor() const override { return randomAccessible(); };
794  Value upperBound(OpBuilder &b, Location l) const override {
795  auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
796  Value parentUB =
797  p && p->lvl == lvl ? p->upperBound(b, l) : delegate->upperBound(b, l);
798  return ADDI(SUBI(parentUB, subSectSz), C_IDX(1));
799  };
800 
801  void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override;
802 
803  void locateImpl(OpBuilder &b, Location l, Value crd) override {
804  Value absOff = crd;
805 
806  if (isSubSectRoot())
807  delegate->locate(b, l, absOff);
808  else
809  assert(parent->lvl + 1 == lvl);
810 
811  seek(ValueRange{absOff, absOff, C_TRUE});
812  updateCrd(crd);
813  }
814 
815  Value toSubSectCrd(OpBuilder &b, Location l, Value wrapCrd) const {
816  return SUBI(wrapCrd, getAbsOff());
817  }
818 
819  Value genNotEndImpl(OpBuilder &b, Location l) override {
820  return getNotEnd();
821  };
822 
823  Value derefImpl(OpBuilder &b, Location l) override {
824  // Use the relative offset to coiterate.
825  Value crd;
826  auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
827  if (p && p->lvl == lvl)
828  crd = SUBI(getAbsOff(), p->getAbsOff());
829  else crd = getAbsOff();
830 
831  updateCrd(crd);
832  return crd;
833  };
834 
835  ValueRange forwardImpl(OpBuilder &b, Location l) override;
836 
837  Value getMinCrd() const { return subSectMeta[0]; }
838  Value getAbsOff() const { return subSectMeta[1]; }
839  Value getNotEnd() const { return subSectMeta[2]; }
840 
841  const SparseIterator *parent;
842  std::unique_ptr<SparseIterator> delegate;
843 
844  // Number of values required to serialize the wrapped iterator.
845  const unsigned tupleSz;
846  // Max number of tuples, and the actual number of tuples.
847  Value maxTupleCnt, tupleCnt;
848  // The memory used to cache the tuple serialized from the wrapped iterator.
849  Value subSectPosBuf;
850 
851  const Value subSectSz;
852 
853  // minCrd, absolute offset, notEnd
854  SmallVector<Value, 3> subSectMeta{nullptr, nullptr, nullptr};
855 };
856 
857 class SubSectIterator;
858 
859 // A wrapper that helps generate code to traverse a subsection, used
860 // by both `NonEmptySubSectIterator` and `SubSectIterator`.
861 struct SubSectIterHelper {
862  explicit SubSectIterHelper(const SubSectIterator &iter);
863  explicit SubSectIterHelper(const NonEmptySubSectIterator &subSect);
864 
865  // Delegate methods.
866  void deserializeFromTupleId(OpBuilder &b, Location l, Value tupleId);
867  void locate(OpBuilder &b, Location l, Value crd);
868  Value genNotEnd(OpBuilder &b, Location l);
869  Value deref(OpBuilder &b, Location l);
870  ValueRange forward(OpBuilder &b, Location l);
871 
872  const NonEmptySubSectIterator &subSect;
873  SparseIterator &wrap;
874 };
875 
876 class SubSectIterator : public SparseIterator {
877 public:
878  SubSectIterator(const NonEmptySubSectIterator &subSect,
879  const SparseIterator &parent,
880  std::unique_ptr<SparseIterator> &&wrap)
881      : SparseIterator(IterKind::kSubSect, *wrap,
882                       /*extraCursorCnt=*/wrap->randomAccessible() ? 0 : 1),
883  subSect(subSect), wrap(std::move(wrap)), parent(parent), helper(*this) {
884  assert(subSect.tid == tid && subSect.lvl == lvl);
885  assert(parent.kind != IterKind::kSubSect || parent.lvl + 1 == lvl);
886  };
887 
888  // For LLVM-style RTTI.
889  static bool classof(const SparseIterator *from) {
890  return from->kind == IterKind::kSubSect;
891  }
892 
893  std::string getDebugInterfacePrefix() const override {
894  return std::string("subsect<") + wrap->getDebugInterfacePrefix() + ">";
895  }
896  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
897  SmallVector<Type> ret = wrap->getCursorValTypes(b);
898  if (!randomAccessible())
899  ret.push_back(b.getIndexType()); // The extra counter.
900  return ret;
901  }
902 
903  bool isBatchIterator() const override { return wrap->isBatchIterator(); }
904  bool randomAccessible() const override { return wrap->randomAccessible(); };
905  bool iteratableByFor() const override { return randomAccessible(); };
906  Value upperBound(OpBuilder &b, Location l) const override {
907  return subSect.subSectSz;
908  }
909 
910  ValueRange getCurPosition() const override { return wrap->getCurPosition(); };
911 
912  Value getNxLvlTupleId(OpBuilder &b, Location l) const {
913  if (randomAccessible()) {
914  return ADDI(getCrd(), nxLvlTupleStart);
915  };
916  return ADDI(getCursor().back(), nxLvlTupleStart);
917  }
918 
919  void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override {
920  if (randomAccessible()) {
921  if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
922  assert(p->lvl + 1 == lvl);
923  wrap->genInit(b, l, p);
924  // Linearize the dense subsection index.
925  nxLvlTupleStart = MULI(subSect.subSectSz, p->getNxLvlTupleId(b, l));
926  } else {
927  assert(subSect.lvl == lvl && subSect.isSubSectRoot());
928  wrap->deserialize(subSect.delegate->serialize());
929  nxLvlTupleStart = C_IDX(0);
930  }
931  return;
932  }
933  assert(!randomAccessible());
934  assert(getCursor().size() == wrap->getCursor().size() + 1);
935  // Extra counter that counts the number of actually visited coordinates in
936  // the sparse subsection.
937  getMutCursorVals().back() = C_IDX(0);
938  Value tupleId;
939  if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
940  assert(p->lvl + 1 == lvl);
941  tupleId = p->getNxLvlTupleId(b, l);
942  } else {
943  assert(subSect.lvl == lvl && subSect.isSubSectRoot());
944  tupleId = C_IDX(0);
945  }
946  nxLvlTupleStart = subSect.loadNxLvlStart(b, l, tupleId);
947  helper.deserializeFromTupleId(b, l, tupleId);
948  }
949 
950  void locateImpl(OpBuilder &b, Location l, Value crd) override {
951  helper.locate(b, l, crd);
952  updateCrd(crd);
953  }
954 
955  Value genNotEndImpl(OpBuilder &b, Location l) override {
956  return helper.genNotEnd(b, l);
957  }
958 
959  Value derefImpl(OpBuilder &b, Location l) override {
960  Value crd = helper.deref(b, l);
961  updateCrd(crd);
962  return crd;
963  };
964 
965  ValueRange forwardImpl(OpBuilder &b, Location l) override {
966  helper.forward(b, l);
967  assert(!randomAccessible());
968  assert(getCursor().size() == wrap->getCursor().size() + 1);
969  getMutCursorVals().back() = ADDI(getCursor().back(), C_IDX(1));
970  return getCursor();
971  };
972 
973  Value nxLvlTupleStart;
974 
975  const NonEmptySubSectIterator &subSect;
976  std::unique_ptr<SparseIterator> wrap;
977  const SparseIterator &parent;
978 
979  SubSectIterHelper helper;
980 };
981 
982 } // namespace
983 
984 //===----------------------------------------------------------------------===//
985 // SparseIterator derived classes implementation.
986 //===----------------------------------------------------------------------===//
987 
988 void SparseIterator::genInit(OpBuilder &b, Location l,
989                              const SparseIterator *p) {
990   if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
991     std::string prefix = getDebugInterfacePrefix();
992  Operation *begin = b.create(l, b.getStringAttr(prefix + ".begin"), {},
993  getCursorValTypes(b));
994  seek(begin->getResults());
995  return;
996  }
997   // Inherit batch coordinates from the parent.
998  if (p)
999  inherentBatch(*p);
1000  // TODO: support lowering to function call.
1001  return genInitImpl(b, l, p);
1002 }
1003 
1004 Value SparseIterator::genNotEnd(OpBuilder &b, Location l) {
1005   if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
1006     std::string prefix = getDebugInterfacePrefix();
1007  Operation *notEnd = b.create(l, b.getStringAttr(prefix + ".not_end"),
1008  getCursor(), b.getI1Type());
1009  return notEnd->getResult(0);
1010  }
1011  // TODO: support lowering to function call.
1012  return genNotEndImpl(b, l);
1013 }
1014 
1015 void SparseIterator::locate(OpBuilder &b, Location l, Value crd) {
1016   if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
1017     std::string prefix = getDebugInterfacePrefix();
1018  SmallVector<Value> args = getCursor();
1019  args.push_back(crd);
1020  Operation *locate = b.create(l, b.getStringAttr(prefix + ".locate"), args,
1021  getCursorValTypes(b));
1022  seek(locate->getResults());
1023  updateCrd(crd);
1024  return;
1025  }
1026  return locateImpl(b, l, crd);
1027 }
1028 
1029 Value SparseIterator::deref(OpBuilder &b, Location l) {
1030   if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
1031     std::string prefix = getDebugInterfacePrefix();
1032  SmallVector<Value> args = getCursor();
1033  Operation *deref = b.create(l, b.getStringAttr(prefix + ".deref"),
1034  getCursor(), b.getIndexType());
1035  updateCrd(deref->getResult(0));
1036  return getCrd();
1037  }
1038  return derefImpl(b, l);
1039 }
1040 
1041 ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
1042   assert(!randomAccessible());
1043   if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
1044     std::string prefix = getDebugInterfacePrefix();
1045     Operation *next = b.create(l, b.getStringAttr(prefix + ".next"),
1046                                getCursor(), getCursorValTypes(b));
1047     seek(next->getResults());
1048  return getCursor();
1049  }
1050  return forwardImpl(b, l);
1051 }
1052 
1053 ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) {
1054   auto ifOp = scf::IfOp::create(b, l, getCursor().getTypes(), cond, true);
1055  // Generate else branch first, otherwise iterator values will be updated by
1056  // `forward()`.
1057  b.setInsertionPointToStart(ifOp.elseBlock());
1058  YIELD(getCursor());
1059 
1060  b.setInsertionPointToStart(ifOp.thenBlock());
1061  YIELD(forward(b, l));
1062 
1063  b.setInsertionPointAfter(ifOp);
1064  seek(ifOp.getResults());
1065  return getCursor();
1066 }
1067 
1068 Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
1069  auto whileOp = scf::WhileOp::create(
1070  b, l, pos.getType(), pos,
1071  /*beforeBuilder=*/
1072  [this, pos](OpBuilder &b, Location l, ValueRange ivs) {
1073  Value inBound = CMPI(ult, ivs.front(), posHi);
1074  auto ifInBound = scf::IfOp::create(b, l, b.getI1Type(), inBound, true);
1075  {
1076  OpBuilder::InsertionGuard guard(b);
1077  // If in bound, load the next coordinates and check duplication.
1078  b.setInsertionPointToStart(ifInBound.thenBlock());
1079  Value headCrd = stl.peekCrdAt(b, l, getBatchCrds(), pos);
1080  Value tailCrd = stl.peekCrdAt(b, l, getBatchCrds(), ivs.front());
1081  Value isDup = CMPI(eq, headCrd, tailCrd);
1082  YIELD(isDup);
1083  // Else, the position is out of bound, yield false.
1084  b.setInsertionPointToStart(ifInBound.elseBlock());
1085  YIELD(constantI1(b, l, false));
1086  }
1087  scf::ConditionOp::create(b, l, ifInBound.getResults()[0], ivs);
1088  },
1089  /*afterBuilder=*/
1090  [](OpBuilder &b, Location l, ValueRange ivs) {
1091  Value nxPos = ADDI(ivs[0], C_IDX(1));
1092  YIELD(nxPos);
1093  });
1094  // Return the segment high.
1095  return whileOp.getResult(0);
1096 }
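// For illustration (values are hypothetical): if the coordinates starting at
// `pos` are [3, 3, 3, 7, ...], the while-loop above keeps advancing while the
// loaded coordinate equals crd[pos], so it returns pos + 3 (the position of
// the 7), and [pos, segHi) covers the whole duplicated segment.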
1097 
1098 Value FilterIterator::genCrdNotLegitPredicate(OpBuilder &b, Location l,
1099  Value wrapCrd) {
1100  Value crd = fromWrapCrd(b, l, wrapCrd);
1101  // Test whether the coordinate is on stride.
1102  Value notlegit = CMPI(ne, toWrapCrd(b, l, crd), wrapCrd);
1103  // Test wrapCrd < offset
1104  notlegit = ORI(CMPI(ult, wrapCrd, offset), notlegit);
1105  // Test crd >= length
1106  notlegit = ORI(CMPI(uge, crd, size), notlegit);
1107  return notlegit;
1108 }
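// For illustration (values are hypothetical): with offset = 2, stride = 3 and
// size = 4, the legitimate wrapped coordinates are 2, 5, 8, 11 (mapping to
// crd = 0..3). wrapCrd = 6 is rejected because it is not on the stride,
// wrapCrd = 1 because it is below the offset, and wrapCrd = 14 because the
// translated crd = 4 is not less than size.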
1109 
1110 Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
1111  auto r = genWhenInBound(
1112  b, l, *wrap, C_FALSE,
1113  [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1114  Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
1115  return {notLegit};
1116  });
1117  return llvm::getSingleElement(r);
1118 }
1119 
1120 Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) {
1121  assert(!wrap->randomAccessible());
1122  auto r = genWhenInBound(
1123  b, l, *wrap, C_FALSE,
1124  [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1125  Value crd = fromWrapCrd(b, l, wrapCrd);
1126  // crd < size
1127  return {CMPI(ult, crd, size)};
1128  });
1129  return llvm::getSingleElement(r);
1130 }
1131 
1132 ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
1133  assert(!randomAccessible());
1134  // Generates
1135  //
1136  // bool isFirst = true;
1137  // while !it.end() && (!legit(*it) || isFirst)
1138  // wrap ++;
1139  // isFirst = false;
1140  //
1141  // We do not hoist the first `wrap++` outside the loop but use an `isFirst`
1142  // flag here because `wrap++` might have a complex implementation (e.g., to
1143  // forward a subsection).
1144  Value isFirst = constantI1(b, l, true);
1145 
1146  SmallVector<Value> whileArgs(getCursor().begin(), getCursor().end());
1147  whileArgs.push_back(isFirst);
1148  auto whileOp = scf::WhileOp::create(
1149  b, l, ValueRange(whileArgs).getTypes(), whileArgs,
1150  /*beforeBuilder=*/
1151  [this](OpBuilder &b, Location l, ValueRange ivs) {
1152  ValueRange isFirst = linkNewScope(ivs);
1153  scf::ValueVector cont =
1154  genWhenInBound(b, l, *wrap, C_FALSE,
1155  [this, isFirst](OpBuilder &b, Location l,
1156  Value wrapCrd) -> scf::ValueVector {
1157  // crd < size && !legit();
1158  Value notLegit =
1159  genCrdNotLegitPredicate(b, l, wrapCrd);
1160  Value crd = fromWrapCrd(b, l, wrapCrd);
1161  Value ret = ANDI(CMPI(ult, crd, size), notLegit);
1162  ret = ORI(ret, llvm::getSingleElement(isFirst));
1163  return {ret};
1164  });
1165  scf::ConditionOp::create(b, l, cont.front(), ivs);
1166  },
1167  /*afterBuilder=*/
1168  [this](OpBuilder &b, Location l, ValueRange ivs) {
1169  linkNewScope(ivs);
1170  wrap->forward(b, l);
1171  SmallVector<Value> yieldVals(getCursor().begin(), getCursor().end());
1172  yieldVals.push_back(constantI1(b, l, false));
1173  YIELD(yieldVals);
1174  });
1175 
1176  b.setInsertionPointAfter(whileOp);
1177  linkNewScope(whileOp.getResults());
1178  return getCursor();
1179 }
1180 
1181 SubSectIterHelper::SubSectIterHelper(const NonEmptySubSectIterator &subSect)
1182  : subSect(subSect), wrap(*subSect.delegate) {}
1183 
1184 SubSectIterHelper::SubSectIterHelper(const SubSectIterator &iter)
1185  : subSect(iter.subSect), wrap(*iter.wrap) {}
1186 
1187 void SubSectIterHelper::deserializeFromTupleId(OpBuilder &b, Location l,
1188  Value tupleId) {
1189  assert(!subSect.randomAccessible());
1190  wrap.deserialize(subSect.loadCursorVals(b, l, tupleId));
1191 }
1192 
1193 void SubSectIterHelper::locate(OpBuilder &b, Location l, Value crd) {
1194  Value absCrd = ADDI(crd, subSect.getAbsOff());
1195  wrap.locate(b, l, absCrd);
1196 }
1197 
1198 Value SubSectIterHelper::genNotEnd(OpBuilder &b, Location l) {
1199  assert(!wrap.randomAccessible());
1200  auto r = genWhenInBound(
1201  b, l, wrap, C_FALSE,
1202  [this](OpBuilder &b, Location l, Value wrapCrd) -> scf::ValueVector {
1203  Value crd = SUBI(wrapCrd, subSect.getAbsOff());
1204  // crd < size
1205  return {CMPI(ult, crd, subSect.subSectSz)};
1206  });
1207  return llvm::getSingleElement(r);
1208 }
1209 
1210 Value SubSectIterHelper::deref(OpBuilder &b, Location l) {
1211  Value wrapCrd = wrap.deref(b, l);
1212  Value crd = subSect.toSubSectCrd(b, l, wrapCrd);
1213  return crd;
1214 }
1215 
1216 ValueRange SubSectIterHelper::forward(OpBuilder &b, Location l) {
1217  return wrap.forward(b, l);
1218 }
1219 
1220 ValueRange NonEmptySubSectIterator::inflateSubSectTree(
1221  OpBuilder &b, Location l, ValueRange reduc, TraverseBuilder builder) const {
1222  // Set up the helper to help traverse a sparse subsection.
1223  SubSectIterHelper helper(*this);
1224  if (!randomAccessible()) {
1225  // The subsection tree has been expanded up to this level and cached;
1226  // traverse all the leaves and expand them to the next level.
1227  SmallVector<Value> iterArgs;
1228  iterArgs.push_back(C_IDX(0));
1229  iterArgs.append(reduc.begin(), reduc.end());
1230  auto forEachLeaf = scf::ForOp::create(
1231  b, l, /*lb=*/C_IDX(0), /*ub=*/tupleCnt, /*step=*/C_IDX(1), iterArgs,
1232  [&helper, &builder](OpBuilder &b, Location l, Value tupleId,
1233  ValueRange iterArgs) {
1234  // Deserialize the iterator at the cached position (tupleId).
1235  helper.deserializeFromTupleId(b, l, tupleId);
1236 
1237  Value cnt = iterArgs.front();
1238  // Record the number of leaf nodes included in the subsection.
1239  // The number indicates the starting tupleId for the next level that
1240  // corresponds to the current node.
1241  helper.subSect.storeNxLvlStart(b, l, tupleId, cnt);
1242 
1243  SmallVector<Value> whileArgs(helper.wrap.getCursor());
1244  whileArgs.append(iterArgs.begin(), iterArgs.end());
1245 
1246  auto whileOp = scf::WhileOp::create(
1247  b, l, ValueRange(whileArgs).getTypes(), whileArgs,
1248  /*beforeBuilder=*/
1249  [&helper](OpBuilder &b, Location l, ValueRange ivs) {
1250  helper.wrap.linkNewScope(ivs);
1251  scf::ConditionOp::create(b, l, helper.genNotEnd(b, l), ivs);
1252  },
1253  /*afterBuilder=*/
1254  [&helper, &builder](OpBuilder &b, Location l, ValueRange ivs) {
1255  ValueRange remIter = helper.wrap.linkNewScope(ivs);
1256  Value cnt = remIter.front();
1257  ValueRange userIter = remIter.drop_front();
1258  scf::ValueVector userNx = builder(b, l, &helper.wrap, userIter);
1259 
1260  SmallVector<Value> nxIter = helper.forward(b, l);
1261  nxIter.push_back(ADDI(cnt, C_IDX(1)));
1262  nxIter.append(userNx.begin(), userNx.end());
1263  YIELD(nxIter);
1264  });
1265  ValueRange res = helper.wrap.linkNewScope(whileOp.getResults());
1266  YIELD(res);
1267  });
1268  return forEachLeaf.getResults().drop_front();
1269  }
1270 
1271  assert(randomAccessible());
1272  // Helper lambda that traverses the current dense subsection range.
1273  auto visitDenseSubSect = [&, this](OpBuilder &b, Location l,
1274  const SparseIterator *parent,
1275  ValueRange reduc) {
1276  assert(!parent || parent->lvl + 1 == lvl);
1277  delegate->genInit(b, l, parent);
1278  auto forOp = scf::ForOp::create(
1279  b, l, /*lb=*/C_IDX(0), /*ub=*/subSectSz, /*step=*/C_IDX(1), reduc,
1280  [&](OpBuilder &b, Location l, Value crd, ValueRange iterArgs) {
1281  helper.locate(b, l, crd);
1282  scf::ValueVector nx = builder(b, l, &helper.wrap, iterArgs);
1283  YIELD(nx);
1284  });
1285  return forOp.getResults();
1286  };
1287 
1288  if (isSubSectRoot()) {
1289  return visitDenseSubSect(b, l, parent, reduc);
1290  }
1291  // Else, this is not the root, recurse until root.
1292  auto *p = llvm::cast<NonEmptySubSectIterator>(parent);
1293  assert(p->lvl + 1 == lvl);
1294  return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
1295 }
1296 
1297 void TrivialIterator::genInitImpl(OpBuilder &b, Location l,
1298  const SparseIterator *parent) {
1299 
1300  if (isBatchIterator() && batchCrds.size() <= stl.lvl)
1301  batchCrds.resize(stl.lvl + 1, nullptr);
1302 
1303  Value c0 = C_IDX(0);
1304  ValueRange pPos = c0;
1305  Value inPadZone = nullptr;
1306  // If the parent iterator is a batch iterator, we also start from 0 (but
1307  // on a different batch).
1308  if (parent && !parent->isBatchIterator()) {
1309  pPos = parent->getCurPosition();
1310  if (llvm::isa<PadIterator>(parent) && parent->randomAccessible()) {
1311  // A padded dense iterator creates a "sparse" pad zone, which needs to be
1312  // handled specially.
1313  inPadZone = pPos.back();
1314  pPos = pPos.drop_back();
1315  }
1316  }
1317 
1318  ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
1319  std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos, inPadZone);
1320  // Seek to the lowest position.
1321  seek(posLo);
1322 }
1323 
1324 void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
1325  const SparseIterator *) {
1326  Value c0 = C_IDX(0);
1327  if (!isSubSectRoot()) {
1328  assert(parent->lvl + 1 == lvl);
1329  if (randomAccessible()) {
1330  // We cannot call wrap->genInit() here to initialize the wrapped
1331  // iterator, because the parent of the current iterator is still
1332  // unresolved.
1333  seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
1334  return;
1335  }
1336 
1337  auto *p = cast<NonEmptySubSectIterator>(parent);
1338  SmallVector<Value, 3> reduc = {
1339  C_IDX(-1), // minCrd (max signless integer)
1340  c0, // tupleId
1341  };
1342 
1343  // Expand the subsection tree from the parent level to the current level.
1344  ValueRange result = p->inflateSubSectTree(
1345  b, l, reduc,
1346  [this](OpBuilder &b, Location l, const SparseIterator *parent,
1347  ValueRange reduc) -> scf::ValueVector {
1348  assert(parent->lvl + 1 == lvl && reduc.size() == 2);
1349  Value minCrd = reduc.front();
1350  Value tupleId = reduc.back();
1351 
1352  // Initialize the subsection range.
1353  SubSectIterHelper helper(*this);
1354  helper.wrap.genInit(b, l, parent);
1355 
1356  // Update minCrd.
1357  minCrd = genWhenInBound(b, l, helper.wrap, minCrd,
1358  [minCrd](OpBuilder &b, Location l,
1359  Value crd) -> scf::ValueVector {
1360  Value min = MINUI(crd, minCrd);
1361  return {min};
1362  })
1363  .front();
1364 
1365  // Cache the sparse range.
1366  storeCursorVals(b, l, tupleId, helper.wrap.serialize());
1367  tupleId = ADDI(tupleId, C_IDX(1));
1368  return {minCrd, tupleId};
1369  });
1370  assert(result.size() == 2);
1371  tupleCnt = result.back();
1372 
1373  Value minCrd = result.front();
1374  Value absOff = offsetFromMinCrd(b, l, minCrd, subSectSz);
1375  Value notEnd = CMPI(ne, minCrd, C_IDX(-1));
1376  seek({minCrd, absOff, notEnd});
1377  return;
1378  }
1379 
1380  // This is the root level of the subsection, which means that it is resolved
1381  // to one node.
1382  assert(isSubSectRoot());
1383 
1384  // Initialize the position; the position marks the *lower bound* of the
1385  // subrange. The upper bound is determined by the size of the subsection.
1386  delegate->genInit(b, l, parent);
1387  if (randomAccessible()) {
1388  seek({/*minCrd=*/c0, /*offset=*/c0, /*notEnd=*/C_TRUE});
1389  return;
1390  }
1391 
1392  // Only have one root node.
1393  tupleCnt = C_IDX(1);
1394  // Cache the sparse range.
1395  storeCursorVals(b, l, c0, delegate->serialize());
1396  SmallVector<Value> elseRet{c0, c0, /*notEnd=*/C_FALSE};
1397  auto meta = genWhenInBound(
1398  b, l, *delegate, elseRet,
1399  [this](OpBuilder &b, Location l, Value crd) -> scf::ValueVector {
1400  Value offset = offsetFromMinCrd(b, l, crd, subSectSz);
1401  return {crd, offset, C_TRUE};
1402  });
1403 
1404  seek(meta);
1405 }
1406 
1407 ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
1408  assert(!randomAccessible());
1409  Value c0 = C_IDX(0), c1 = C_IDX(1);
1410  // Forward to the next non empty slice by generating
1411  //
1412  // if (minCrd > offset) {
1413  // offset += 1
1414  // } else {
1415  // minCrd = nextMinInSlice();
1416  // offset = minCrd - size + 1;
1417  // }
1418  //
1419  // if (offset + size > parents.size)
1420  // isNonEmpty = false;
1421  Value fastPathP = CMPI(ugt, getMinCrd(), getAbsOff());
1422  auto ifOp = scf::IfOp::create(b, l, getCursor().getTypes(), fastPathP, true);
1423  {
1424  OpBuilder::InsertionGuard guard(b);
1425  // Take the fast path
1426  // if (minCrd > offset)
1427  // offset += 1
1428  b.setInsertionPointToStart(&ifOp.getThenRegion().front());
1429  Value nxOffset = ADDI(getAbsOff(), c1);
1430  YIELD((ValueRange{getMinCrd(), nxOffset, getNotEnd()}));
1431 
1432  // else /*minCrd == offset*/ {
1433  // for (i = 0; i < tupleCnt; i++) {
1434  // wrap->deserialize(pos[i]);
1435  // minCrd=min(minCrd, *wrap);
1436  // }
1437  // offset = minCrd - size + 1;
1438  // }
1439  b.setInsertionPointToStart(&ifOp.getElseRegion().front());
1440  SmallVector<Value, 2> loopArgs{C_IDX(-1), // nextMinCrd
1441  C_FALSE}; // isNotEnd
1442  auto loopNest = scf::buildLoopNest(
1443  b, l, c0, tupleCnt, c1, loopArgs,
1444  [this](OpBuilder &b, Location l, ValueRange ivs,
1445  ValueRange iterArgs) -> scf::ValueVector {
1446  Value tupleId = ivs.front();
1447  SubSectIterHelper helper(*this);
1448  helper.deserializeFromTupleId(b, l, tupleId);
1449 
1450  return genWhenInBound(
1451  b, l, *delegate, /*elseRet=*/iterArgs,
1452  [this, iterArgs, tupleId](OpBuilder &b, Location l,
1453  Value crd) -> scf::ValueVector {
1454  // if coord == minCrd
1455  // wrap->forward();
1456  Value isMin = CMPI(eq, crd, getMinCrd());
1457  delegate->forwardIf(b, l, isMin);
1458  // Update the forwarded iterator values if needed.
1459  auto ifIsMin = scf::IfOp::create(b, l, isMin, false);
1460  b.setInsertionPointToStart(&ifIsMin.getThenRegion().front());
1461  storeCursorVals(b, l, tupleId, delegate->serialize());
1462  b.setInsertionPointAfter(ifIsMin);
1463  // if (!wrap.end())
1464  // yield(min(nxMinCrd, *wrap), true)
1465  Value nxMin = iterArgs[0];
1466  return genWhenInBound(b, l, *delegate, /*elseRet=*/iterArgs,
1467  [nxMin](OpBuilder &b, Location l,
1468  Value crd) -> scf::ValueVector {
1469  Value nx = arith::MinUIOp::create(
1470  b, l, crd, nxMin);
1471  return {nx, C_TRUE};
1472  });
1473  });
1474  });
1475 
1476  scf::ForOp forOp = loopNest.loops.front();
1477  b.setInsertionPointAfter(forOp);
1478 
1479  Value nxMinCrd = forOp.getResult(0);
1480  Value nxNotEnd = forOp.getResult(1);
1481  Value nxAbsOff = offsetFromMinCrd(b, l, nxMinCrd, subSectSz);
1482  YIELD((ValueRange{nxMinCrd, nxAbsOff, nxNotEnd}));
1483  }
1484 
1485  Value nxMinCrd = ifOp.getResult(0);
1486  Value nxAbsOff = ifOp.getResult(1);
1487  Value nxNotEnd = ifOp.getResult(2);
1488 
1489  // We should at least forward the offset by one.
1490  Value minAbsOff = ADDI(getAbsOff(), c1);
1491  nxAbsOff = arith::MaxUIOp::create(b, l, minAbsOff, nxAbsOff);
1492 
1493  seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
1494  // The coordinate should not exceed the space's upper bound.
1495  Value crd = deref(b, l);
1496  nxNotEnd = ANDI(nxNotEnd, CMPI(ult, crd, upperBound(b, l)));
1497 
1498  seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
1499  return getCursor();
1500 }
1501 
1502 //===----------------------------------------------------------------------===//
1503 // SparseIterationSpace Implementation
1504 //===----------------------------------------------------------------------===//
1505 
1506 mlir::sparse_tensor::SparseIterationSpace::SparseIterationSpace(
1507     Location l, OpBuilder &b, Value t, unsigned tid,
1508  std::pair<Level, Level> lvlRange, ValueRange parentPos)
1509  : lvls() {
1510  auto [lvlLo, lvlHi] = lvlRange;
1511 
1512  Value c0 = C_IDX(0);
1513  if (parentPos.empty())
1514  parentPos = c0;
1515 
1516  for (Level lvl = lvlLo; lvl < lvlHi; lvl++)
1517  lvls.emplace_back(makeSparseTensorLevel(b, l, t, tid, lvl));
1518 
1519  bound = lvls.front()->peekRangeAt(b, l, /*batchPrefix=*/{}, parentPos);
1520  for (auto &lvl : getLvlRef().drop_front())
1521  bound = lvl->collapseRangeBetween(b, l, /*batchPrefix=*/{}, bound);
1522 }
1523 
1524 SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
1525     IterSpaceType dstTp, ValueRange values, unsigned int tid) {
1526  // Reconstruct every sparse tensor level.
1527  SparseIterationSpace space;
1528  for (auto [i, lt] : llvm::enumerate(dstTp.getLvlTypes())) {
1529  unsigned bufferCnt = 0;
1530  if (lt.isWithPosLT())
1531  bufferCnt++;
1532  if (lt.isWithCrdLT())
1533  bufferCnt++;
1534  // Sparse tensor buffers.
1535  ValueRange buffers = values.take_front(bufferCnt);
1536  values = values.drop_front(bufferCnt);
1537 
1538  // Level size.
1539  Value sz = values.front();
1540  values = values.drop_front();
1541  space.lvls.push_back(
1542  makeSparseTensorLevel(lt, sz, buffers, tid, i + dstTp.getLoLvl()));
1543  }
1544  // Two bounds.
1545  space.bound = std::make_pair(values[0], values[1]);
1546  values = values.drop_front(2);
1547 
1548  // Must have consumed all values.
1549  assert(values.empty());
1550  return space;
1551 }
1552 
1553 std::unique_ptr<SparseIterator>
1554 SparseIterationSpace::extractIterator(OpBuilder &b, Location l) const {
1555   return makeSimpleIterator(b, l, *this);
1556 }
1557 
1558 //===----------------------------------------------------------------------===//
1559 // SparseIterator factory functions.
1560 //===----------------------------------------------------------------------===//
1561 
1562 /// Helper function to create a TensorLevel object from the given `tensor`.
1563 std::unique_ptr<SparseTensorLevel>
1564 sparse_tensor::makeSparseTensorLevel(LevelType lt, Value sz, ValueRange b,
1565                                      unsigned t, Level l) {
1566  assert(lt.getNumBuffer() == b.size());
1567  switch (lt.getLvlFmt()) {
1568  case LevelFormat::Dense:
1569  return std::make_unique<DenseLevel>(t, l, sz);
1570  case LevelFormat::Batch:
1571  return std::make_unique<BatchLevel>(t, l, sz);
1572  case LevelFormat::Compressed:
1573    return std::make_unique<CompressedLevel>(t, l, lt, sz, b[0], b[1]);
1574  case LevelFormat::LooseCompressed:
1575    return std::make_unique<LooseCompressedLevel>(t, l, lt, sz, b[0], b[1]);
1576  case LevelFormat::Singleton:
1577    return std::make_unique<SingletonLevel>(t, l, lt, sz, b[0]);
1578  case LevelFormat::NOutOfM:
1579  return std::make_unique<NOutOfMLevel>(t, l, lt, sz, b[0]);
1580  case LevelFormat::Undef:
1581  llvm_unreachable("undefined level format");
1582  }
1583  llvm_unreachable("unrecognizable level format");
1584 }
1585 
1586 std::unique_ptr<SparseTensorLevel>
1587 sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
1588                                      unsigned tid, Level lvl) {
1589  auto stt = getSparseTensorType(t);
1590 
1591  LevelType lt = stt.getLvlType(lvl);
1592  Value sz = stt.hasEncoding()
1593  ? LvlOp::create(b, l, t, lvl).getResult()
1594  : tensor::DimOp::create(b, l, t, lvl).getResult();
1595 
1596  SmallVector<Value, 2> buffers;
1597  if (lt.isWithPosLT()) {
1598  Value pos = ToPositionsOp::create(b, l, t, lvl);
1599  buffers.push_back(pos);
1600  }
1601  if (lt.isWithCrdLT()) {
1602  Value crd = ToCoordinatesOp::create(b, l, t, lvl);
1603  buffers.push_back(crd);
1604  }
1605  return makeSparseTensorLevel(lt, sz, buffers, tid, lvl);
1606 }
1607 
1608 std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
1609 sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
1610  SparseEmitStrategy strategy) {
1611  auto stl = std::make_unique<BatchLevel>(tid, lvl, sz);
1612  auto it = std::make_unique<TrivialIterator>(*stl);
1613  it->setSparseEmitStrategy(strategy);
1614  return std::make_pair(std::move(stl), std::move(it));
1615 }
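// NOTE (added sketch, not part of the original file): typical use of the
// factory above when a loop has to iterate over a dense space of size `sz`
// that is not backed by any sparse tensor. The helper name is hypothetical;
// the synthetic level must outlive the iterator that references it.
static std::pair<std::unique_ptr<SparseTensorLevel>,
                 std::unique_ptr<SparseIterator>>
sketchMakeSyntheticIterator(OpBuilder &b, Location l, Value sz) {
  auto [stl, it] = makeSynLevelAndIterator(sz, /*tid=*/0, /*lvl=*/0,
                                           SparseEmitStrategy::kFunctional);
  // A synthetic level has no parent iterator to link against.
  it->genInit(b, l, /*p=*/nullptr);
  return {std::move(stl), std::move(it)};
}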
1616 
1617 std::unique_ptr<SparseIterator>
1618 sparse_tensor::makeSimpleIterator(OpBuilder &b, Location l,
1619  const SparseIterationSpace &iterSpace) {
1620  // assert(iterSpace.getSpaceDim() == 1);
1621  std::unique_ptr<SparseIterator> ret;
1622  if (!iterSpace.isUnique()) {
1623  // We always deduplicate the non-unique level, but we should optimize it away
1624  // if possible.
1625  ret = std::make_unique<DedupIterator>(b, l, iterSpace.getLastLvl(),
1626  iterSpace.getBoundLo(),
1627  iterSpace.getBoundHi());
1628  } else {
1629  ret = std::make_unique<TrivialIterator>(b, l, iterSpace.getLastLvl(),
1630  iterSpace.getBoundLo(),
1631  iterSpace.getBoundHi());
1632  }
1633  ret->setSparseEmitStrategy(SparseEmitStrategy::kFunctional);
1634  return ret;
1635 }
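// NOTE (added, not part of the original file): a conceptual outline of the
// call sequence a client follows with the iterator returned above. In-tree
// lowering threads the iterator cursor through an scf.while op; this outline
// only names the SparseIterator interface methods involved.
//
//   it->genInit(b, l, /*p=*/nullptr);              // position at first entry
//   loop condition: it->genNotEnd(b, l);           // i1: not yet exhausted
//   loop body:      Value crd = it->deref(b, l);   // current coordinate
//                   ... use crd ...
//                   it->forward(b, l);             // advance to next entry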
1636 
1637 std::unique_ptr<SparseIterator>
1638 sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
1639  SparseEmitStrategy strategy) {
1640  std::unique_ptr<SparseIterator> ret;
1641  if (!isUniqueLT(stl.getLT())) {
1643  // We always deduplicate the non-unique level, but we should optimize it away
1644  // if possible.
1644  ret = std::make_unique<DedupIterator>(stl);
1645  } else {
1646  ret = std::make_unique<TrivialIterator>(stl);
1647  }
1648  ret->setSparseEmitStrategy(strategy);
1649  return ret;
1650 }
1651 
1652 std::unique_ptr<SparseIterator>
1653 sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
1654  Value offset, Value stride, Value size,
1655  SparseEmitStrategy strategy) {
1656 
1657  auto ret =
1658  std::make_unique<FilterIterator>(std::move(sit), offset, stride, size);
1659  ret->setSparseEmitStrategy(strategy);
1660  return ret;
1661 }
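// NOTE (added sketch, not part of the original file): wrapping a level
// iterator so that it only visits the slice selected by `offset`, `stride`,
// and `size`, as a caller of the factory above might do. The helper name is
// hypothetical.
static std::unique_ptr<SparseIterator>
sketchMakeSliceIterator(const SparseTensorLevel &stl, Value offset,
                        Value stride, Value size) {
  std::unique_ptr<SparseIterator> it =
      makeSimpleIterator(stl, SparseEmitStrategy::kFunctional);
  // The returned FilterIterator skips coordinates outside the slice and
  // remaps visited coordinates into the slice's local coordinate space.
  return makeSlicedLevelIterator(std::move(it), offset, stride, size,
                                 SparseEmitStrategy::kFunctional);
}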
1662 
1663 std::unique_ptr<SparseIterator>
1664 sparse_tensor::makePaddedIterator(std::unique_ptr<SparseIterator> &&sit,
1665  Value padLow, Value padHigh,
1666  SparseEmitStrategy strategy) {
1667  auto ret = std::make_unique<PadIterator>(std::move(sit), padLow, padHigh);
1668  ret->setSparseEmitStrategy(strategy);
1669  return ret;
1670 }
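// NOTE (added sketch, not part of the original file): wrapping a level
// iterator so that the level appears padded with `padLow` implicit (zero)
// entries in front and `padHigh` at the end, as a caller of the factory above
// might do. The helper name is hypothetical.
static std::unique_ptr<SparseIterator>
sketchMakePaddedLevelIterator(const SparseTensorLevel &stl, Value padLow,
                              Value padHigh) {
  std::unique_ptr<SparseIterator> it =
      makeSimpleIterator(stl, SparseEmitStrategy::kFunctional);
  return makePaddedIterator(std::move(it), padLow, padHigh,
                            SparseEmitStrategy::kFunctional);
}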
1671 
1672 static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) {
1673  auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
1674  if (filter)
1675  return &filter->getWrappedIterator();
1676  return it;
1677 }
1678 
1679 std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
1680  OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
1681  std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
1682  SparseEmitStrategy strategy) {
1683 
1684  // Try to unwrap the NonEmptySubSectIterator from a filter parent.
1685  parent = tryUnwrapFilter(parent);
1686  std::unique_ptr<SparseIterator> it =
1687  std::make_unique<NonEmptySubSectIterator>(b, l, parent,
1688  std::move(delegate), size);
1689 
1690  if (stride != 1) {
1691  // TODO: We can safely skip bound checking on sparse levels, but for a dense
1692  // iteration space, we need the bound to infer the dense loop range.
1693  it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1694  C_IDX(stride), /*size=*/loopBound);
1695  }
1696  it->setSparseEmitStrategy(strategy);
1697  return it;
1698 }
1699 
1700 std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator(
1701  OpBuilder &b, Location l, const SparseIterator &subSectIter,
1702  const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
1703  Value loopBound, unsigned stride, SparseEmitStrategy strategy) {
1704 
1705  // This must be a subsection iterator or a filtered subsection iterator.
1706  auto &subSect =
1707  llvm::cast<NonEmptySubSectIterator>(*tryUnwrapFilter(&subSectIter));
1708 
1709  std::unique_ptr<SparseIterator> it = std::make_unique<SubSectIterator>(
1710  subSect, *tryUnwrapFilter(&parent), std::move(wrap));
1711 
1712  if (stride != 1) {
1713  it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1714  C_IDX(stride), /*size=*/loopBound);
1715  }
1716  it->setSparseEmitStrategy(strategy);
1717  return it;
1718 }
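// NOTE (added, not part of the original file): the two factories above are
// designed to be used as a pair when lowering convolution-like accesses.
// makeNonEmptySubSectIterator() builds the *outer* iterator that visits only
// the non-empty subsections (windows) of `size` coordinates, while
// makeTraverseSubSectIterator() builds the *inner* iterator that walks the
// entries of the subsection currently selected by that outer iterator. In
// both cases a non-unit `stride` is handled by wrapping the result in a
// FilterIterator bounded by the dense loop bound.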
1719 
1720 #undef CMPI
1721 #undef C_IDX
1722 #undef YIELD
1723 #undef ADDI
1724 #undef ANDI
1725 #undef SUBI
1726 #undef MULI
1727 #undef SELECT