MLIR  19.0.0git
SparseTensorIterator.h
Go to the documentation of this file.
1 //===- SparseTensorIterator.h ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
10 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
11 
14 
15 namespace mlir {
16 namespace sparse_tensor {
17 
18 /// The base class for all types of sparse tensor levels. It provides interfaces
19 /// to query the loop range (see `peekRangeAt`) and look up the coordinates (see
20 /// `peekCrdAt`).
23  SparseTensorLevel(const SparseTensorLevel &) = delete;
24  SparseTensorLevel &operator=(SparseTensorLevel &&) = delete;
25  SparseTensorLevel &operator=(const SparseTensorLevel &) = delete;
26 
27 public:
28  virtual ~SparseTensorLevel() = default;
29 
30  std::string toString() const {
31  return std::string(toMLIRString(lt)) + "[" + std::to_string(tid) + "," +
32  std::to_string(lvl) + "]";
33  }
34 
35  virtual Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
36  Value iv) const = 0;
37 
/// Peeks the lower and upper bounds needed to *fully* traverse the level,
/// given the position `parentPos` at which the immediate parent level
/// currently sits (see SparseIterator::getCurPosition()). Returns a pair of
/// values for *posLo* and *loopHi* respectively.
///
/// For a dense level, *posLo* is the linearized position at the beginning,
/// while *loopHi* is the largest *coordinate*; this also implies that the
/// smallest *coordinate* at which the loop starts is 0.
///
/// For a sparse level, [posLo, loopHi) specifies the range of index pointers
/// used to load coordinates from the coordinate buffer.
49  virtual std::pair<Value, Value>
51  ValueRange parentPos, Value inPadZone = nullptr) const = 0;
52 
53  Level getLevel() const { return lvl; }
54  LevelType getLT() const { return lt; }
55  Value getSize() const { return lvlSize; }
56  virtual ValueRange getLvlBuffers() const = 0;
57 
58  //
59  // Level properties
60  //
61  bool isUnique() const { return isUniqueLT(lt); }
62 
63 protected:
65  : tid(tid), lvl(lvl), lt(lt), lvlSize(lvlSize){};
66 
67 public:
68  const unsigned tid, lvl;
69  const LevelType lt;
70  const Value lvlSize;
71 };
72 
73 enum class IterKind : uint8_t {
74  kTrivial,
75  kDedup,
76  kSubSect,
78  kFilter,
79  kPad,
80 };
81 
82 /// Helper class that generates loop conditions, etc, to traverse a
83 /// sparse tensor level.
85  SparseIterator(SparseIterator &&) = delete;
86  SparseIterator(const SparseIterator &) = delete;
87  SparseIterator &operator=(SparseIterator &&) = delete;
88  SparseIterator &operator=(const SparseIterator &) = delete;
89 
90 protected:
91  SparseIterator(IterKind kind, unsigned tid, unsigned lvl,
92  unsigned cursorValsCnt,
93  SmallVectorImpl<Value> &cursorValStorage)
94  : batchCrds(0), kind(kind), tid(tid), lvl(lvl), crd(nullptr),
95  cursorValsCnt(cursorValsCnt), cursorValsStorageRef(cursorValStorage){};
96 
97  SparseIterator(IterKind kind, unsigned cursorValsCnt,
98  SmallVectorImpl<Value> &cursorValStorage,
99  const SparseIterator &delegate)
100  : SparseIterator(kind, delegate.tid, delegate.lvl, cursorValsCnt,
101  cursorValStorage){};
102 
104  unsigned extraCursorCnt = 0)
106  extraCursorCnt + wrap.cursorValsCnt,
107  wrap.cursorValsStorageRef) {
108  assert(wrap.cursorValsCnt == wrap.cursorValsStorageRef.size());
109  cursorValsStorageRef.append(extraCursorCnt, nullptr);
110  assert(cursorValsStorageRef.size() == wrap.cursorValsCnt + extraCursorCnt);
111  };
112 
113 public:
114  virtual ~SparseIterator() = default;
115 
117  emitStrategy = strategy;
118  }
119 
120  virtual std::string getDebugInterfacePrefix() const = 0;
122 
123  Value getCrd() const { return crd; }
124  ValueRange getBatchCrds() const { return batchCrds; }
126  return ValueRange(cursorValsStorageRef).take_front(cursorValsCnt);
127  };
128 
// Sets the iterator to the specified position.
130  void seek(ValueRange vals) {
131  assert(vals.size() == cursorValsCnt);
132  std::copy(vals.begin(), vals.end(), cursorValsStorageRef.begin());
133  // Now that the iterator is re-positioned, the coordinate becomes invalid.
134  crd = nullptr;
135  }
136 
137  //
138  // Iterator properties.
139  //
140 
// Whether the iterator is an iterator over a batch level.
142  virtual bool isBatchIterator() const = 0;
143 
// Whether the iterator supports random access (i.e., supports lookup by
// *coordinate*). A random-access iterator must also traverse a dense space.
146  virtual bool randomAccessible() const = 0;
147 
// Whether the iterator can simply be traversed by a for loop.
149  virtual bool iteratableByFor() const { return false; };
150 
// Gets the upper bound of the sparse space that the iterator might visit. A
// sparse space is a subset of a dense space [0, bound); this function returns
// *bound*.
154  virtual Value upperBound(OpBuilder &b, Location l) const = 0;
155 
// Serializes and deserializes the current status to/from a set of values. The
// ValueRange should contain values that are sufficient to recover the current
// iterating position (i.e., the cursor values) as well as the loop bound.
//
// Not every type of iterator supports these operations, e.g., the non-empty
// subsection iterator does not, because the number of non-empty subsections
// cannot be determined easily.
//
// NOTE: All the values should have index type.
165  virtual SmallVector<Value> serialize() const {
166  llvm_unreachable("unsupported");
167  };
168  virtual void deserialize(ValueRange vs) { llvm_unreachable("unsupported"); };
169 
170  //
171  // Core functions.
172  //
173 
174  // Initializes the iterator according to the parent iterator's state.
175  void genInit(OpBuilder &b, Location l, const SparseIterator *p);
176 
177  // Forwards the iterator to the next element.
179 
// Locates the iterator at the position specified by *crd*; this can only
// be done on an iterator that supports random access.
182  void locate(OpBuilder &b, Location l, Value crd);
183 
184  // Returns a boolean value that equals `!it.end()`
186 
187  // Dereferences the iterator, loads the coordinate at the current position.
188  //
189  // The method assumes that the iterator is not currently exhausted (i.e.,
190  // it != it.end()).
191  Value deref(OpBuilder &b, Location l);
192 
// Actual implementation provided by derived classes.
194  virtual void genInitImpl(OpBuilder &, Location, const SparseIterator *) = 0;
196  virtual void locateImpl(OpBuilder &b, Location l, Value crd) {
197  llvm_unreachable("Unsupported");
198  }
199  virtual Value genNotEndImpl(OpBuilder &b, Location l) = 0;
200  virtual Value derefImpl(OpBuilder &b, Location l) = 0;
// Gets the ValueRange that together specifies the current position of the
// iterator. For a unique level, the position can be a single index pointing
// to the current coordinate being visited. For a non-unique level, an extra
// index for the `segment high` is needed to specify the range of duplicated
// coordinates. The ValueRange should be able to uniquely identify the sparse
// range for the next level; see SparseTensorLevel::peekRangeAt().
//
// Not every type of iterator supports the operation, e.g., the non-empty
// subsection iterator does not, because it represents a range of coordinates
// instead of just one.
211  virtual ValueRange getCurPosition() const { return getCursor(); };
212 
// Returns a pair of values for the *lower* and *upper* bound respectively.
214  virtual std::pair<Value, Value> genForCond(OpBuilder &b, Location l) {
215  assert(randomAccessible());
216  // Random-access iterator is traversed by coordinate, i.e., [curCrd, UB).
217  return {getCrd(), upperBound(b, l)};
218  }
219 
220  // Generates a bool value for scf::ConditionOp.
221  std::pair<Value, ValueRange> genWhileCond(OpBuilder &b, Location l,
222  ValueRange vs) {
223  ValueRange rem = linkNewScope(vs);
224  return std::make_pair(genNotEnd(b, l), rem);
225  }
226 
227  // Generate a conditional it.next() in the following form
228  //
229  // if (cond)
230  // yield it.next
231  // else
232  // yield it
233  //
234  // The function is virtual to allow alternative implementation. For example,
235  // if it.next() is trivial to compute, we can use a select operation instead.
236  // E.g.,
237  //
238  // it = select cond ? it+1 : it
239  virtual ValueRange forwardIf(OpBuilder &b, Location l, Value cond);
240 
241  // Update the SSA value for the iterator after entering a new scope.
243  assert(!randomAccessible() && "random accessible iterators are traversed "
244  "by coordinate, call locate() instead.");
245  seek(pos.take_front(cursorValsCnt));
246  return pos.drop_front(cursorValsCnt);
247  };
248 
249 protected:
250  void updateCrd(Value crd) { this->crd = crd; }
251 
253  MutableArrayRef<Value> ref = cursorValsStorageRef;
254  return ref.take_front(cursorValsCnt);
255  }
256 
257  void inherentBatch(const SparseIterator &parent) {
258  batchCrds = parent.batchCrds;
259  }
260 
263 
264 public:
265  const IterKind kind; // For LLVM-style RTTI.
266  const unsigned tid, lvl; // tensor level identifier.
267 
268 private:
269  Value crd; // The sparse coordinate used to coiterate;
270 
// A range of values that together define the current state of the
// iterator. Only loop-variant values should be included.
//
// For trivial iterators, it is the position; for dedup iterators, it consists
// of the position and the segment high; for the non-empty subsection
// iterator, it is the metadata that specifies the subsection.
// Note that a wrapped iterator shares the same cursor storage with its
// wrapper, which means the wrapped iterator might only own a subset of all
// the values stored in cursorValsStorageRef (its first cursorValsCnt
// entries).
280  const unsigned cursorValsCnt;
281  SmallVectorImpl<Value> &cursorValsStorageRef;
282 };
283 
284 /// Helper function to create a TensorLevel object from given `tensor`.
285 std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
286  Location l, Value t,
287  unsigned tid,
288  Level lvl);
289 
290 /// Helper function to create a simple SparseIterator object that iterate over
291 /// the SparseTensorLevel.
292 std::unique_ptr<SparseIterator> makeSimpleIterator(const SparseTensorLevel &stl,
293  SparseEmitStrategy strategy);
294 
295 /// Helper function to create a synthetic SparseIterator object that iterates
296 /// over a dense space specified by [0,`sz`).
297 std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
298 makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
299  SparseEmitStrategy strategy);
300 
301 /// Helper function to create a SparseIterator object that iterates over a
302 /// sliced space, the orignal space (before slicing) is traversed by `sit`.
303 std::unique_ptr<SparseIterator>
304 makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset,
305  Value stride, Value size, SparseEmitStrategy strategy);
306 
307 /// Helper function to create a SparseIterator object that iterates over a
308 /// padded sparse level (the padded value must be zero).
309 std::unique_ptr<SparseIterator>
310 makePaddedIterator(std::unique_ptr<SparseIterator> &&sit, Value padLow,
311  Value padHigh, SparseEmitStrategy strategy);
312 
313 /// Helper function to create a SparseIterator object that iterate over the
314 /// non-empty subsections set.
315 std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
316  OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
317  std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
318  SparseEmitStrategy strategy);
319 
320 /// Helper function to create a SparseIterator object that iterates over a
321 /// non-empty subsection created by NonEmptySubSectIterator.
322 std::unique_ptr<SparseIterator> makeTraverseSubSectIterator(
323  OpBuilder &b, Location l, const SparseIterator &subsectIter,
324  const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
325  Value loopBound, unsigned stride, SparseEmitStrategy strategy);
326 
327 } // namespace sparse_tensor
328 } // namespace mlir
329 
330 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Helper class that generates loop conditions, etc, to traverse a sparse tensor level.
virtual std::pair< Value, Value > genForCond(OpBuilder &b, Location l)
MutableArrayRef< Value > getMutCursorVals()
virtual void genInitImpl(OpBuilder &, Location, const SparseIterator *)=0
ValueRange forward(OpBuilder &b, Location l)
SparseIterator(IterKind kind, unsigned cursorValsCnt, SmallVectorImpl< Value > &cursorValStorage, const SparseIterator &delegate)
void setSparseEmitStrategy(SparseEmitStrategy strategy)
virtual bool isBatchIterator() const =0
virtual void locateImpl(OpBuilder &b, Location l, Value crd)
void genInit(OpBuilder &b, Location l, const SparseIterator *p)
virtual Value upperBound(OpBuilder &b, Location l) const =0
virtual Value derefImpl(OpBuilder &b, Location l)=0
Value genNotEnd(OpBuilder &b, Location l)
void locate(OpBuilder &b, Location l, Value crd)
virtual void deserialize(ValueRange vs)
SparseIterator(IterKind kind, const SparseIterator &wrap, unsigned extraCursorCnt=0)
virtual ValueRange forwardIf(OpBuilder &b, Location l, Value cond)
virtual Value genNotEndImpl(OpBuilder &b, Location l)=0
void inherentBatch(const SparseIterator &parent)
ValueRange linkNewScope(ValueRange pos)
virtual std::string getDebugInterfacePrefix() const =0
virtual SmallVector< Value > serialize() const
virtual ValueRange getCurPosition() const
Value deref(OpBuilder &b, Location l)
virtual bool randomAccessible() const =0
virtual ValueRange forwardImpl(OpBuilder &b, Location l)=0
virtual SmallVector< Type > getCursorValTypes(OpBuilder &b) const =0
SparseIterator(IterKind kind, unsigned tid, unsigned lvl, unsigned cursorValsCnt, SmallVectorImpl< Value > &cursorValStorage)
std::pair< Value, ValueRange > genWhileCond(OpBuilder &b, Location l, ValueRange vs)
The base class for all types of sparse tensor levels.
virtual std::pair< Value, Value > peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix, ValueRange parentPos, Value inPadZone=nullptr) const =0
Peeks the lower and upper bounds needed to fully traverse the level, given the position parentPos at which the immediate parent level currently sits.
virtual Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix, Value iv) const =0
virtual ValueRange getLvlBuffers() const =0
SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize)
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition: Diagnostics.h:24
bool isUniqueLT(LevelType lt)
Definition: Enums.h:424
std::string toMLIRString(LevelType lt)
Definition: Enums.h:443
std::unique_ptr< SparseTensorLevel > makeSparseTensorLevel(OpBuilder &b, Location l, Value t, unsigned tid, Level lvl)
Helper function to create a SparseTensorLevel object for a level of the given tensor t.
std::unique_ptr< SparseIterator > makeTraverseSubSectIterator(OpBuilder &b, Location l, const SparseIterator &subsectIter, const SparseIterator &parent, std::unique_ptr< SparseIterator > &&wrap, Value loopBound, unsigned stride, SparseEmitStrategy strategy)
Helper function to create a SparseIterator object that iterates over a non-empty subsection created by NonEmptySubSectIterator.
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:38
std::unique_ptr< SparseIterator > makeSimpleIterator(const SparseTensorLevel &stl, SparseEmitStrategy strategy)
Helper function to create a simple SparseIterator object that iterates over the SparseTensorLevel.
std::pair< std::unique_ptr< SparseTensorLevel >, std::unique_ptr< SparseIterator > > makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl, SparseEmitStrategy strategy)
Helper function to create a synthetic SparseIterator object that iterates over a dense space specified by [0, sz).
std::unique_ptr< SparseIterator > makePaddedIterator(std::unique_ptr< SparseIterator > &&sit, Value padLow, Value padHigh, SparseEmitStrategy strategy)
Helper function to create a SparseIterator object that iterates over a padded sparse level (the padded value must be zero).
std::unique_ptr< SparseIterator > makeSlicedLevelIterator(std::unique_ptr< SparseIterator > &&sit, Value offset, Value stride, Value size, SparseEmitStrategy strategy)
Helper function to create a SparseIterator object that iterates over a sliced space; the original space (before slicing) is traversed by sit.
std::unique_ptr< SparseIterator > makeNonEmptySubSectIterator(OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound, std::unique_ptr< SparseIterator > &&delegate, Value size, unsigned stride, SparseEmitStrategy strategy)
Helper function to create a SparseIterator object that iterates over the set of non-empty subsections.
Include the generated interface declarations.
SparseEmitStrategy
Defines a scope for reinterpret map pass.
Definition: Passes.h:51
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238