MLIR 22.0.0git
SparseTensorIterator.h
Go to the documentation of this file.
1//===- SparseTensorIterator.h ---------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
10#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
11
14
15namespace mlir {
16namespace sparse_tensor {
17
18// Forward declaration: SparseIterationSpace::extractIterator() returns a
// std::unique_ptr<SparseIterator>, but SparseIterator is defined after it.
19class SparseIterator;
20
21/// The base class for all types of sparse tensor levels. It provides interfaces
22/// to query the loop range (see `peekRangeAt`) and look up the coordinates (see
23/// `peekCrdAt`).
///
/// Levels are neither copyable nor movable; factories hand them out through
/// std::unique_ptr (see makeSparseTensorLevel()).
//
// NOTE(review): this extract is missing several lines of this class: the
// `peekRangeAt(` and `collapseRangeBetween(` declarators, the protected
// constructor signature, and the declarations of the `lt` and `lvlSize`
// members referenced below. The doxygen index at the end of this page lists
// the missing signatures; restore them from the original header.
24class SparseTensorLevel {
25 SparseTensorLevel(SparseTensorLevel &&) = delete;
26 SparseTensorLevel(const SparseTensorLevel &) = delete;
27 SparseTensorLevel &operator=(SparseTensorLevel &&) = delete;
28 SparseTensorLevel &operator=(const SparseTensorLevel &) = delete;
29
30public:
31 virtual ~SparseTensorLevel() = default;
32
  /// Returns a debug string of the form "<level-type>[<tid>,<lvl>]".
33 std::string toString() const {
34 return std::string(toMLIRString(lt)) + "[" + std::to_string(tid) + "," +
35 std::to_string(lvl) + "]";
36 }
37
  /// Looks up the coordinate at `iv`. `iv` is presumably a position taken from
  /// the range returned by peekRangeAt(); confirm this. `batchPrefix`
  /// presumably holds the coordinates of the enclosing batch levels; TODO
  /// confirm.
38 virtual Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
39 Value iv) const = 0;
40
41 /// Peeks the lower and upper bounds needed to *fully* traverse the level
42 /// given the position `parentPos` (see SparseIterator::getCurPosition())
43 /// at which the immediate parent level currently is. Returns a pair of
44 /// values for *posLo* and *loopHi* respectively.
45 ///
46 /// For a dense level, *posLo* is the linearized position at the beginning,
47 /// while *loopHi* is the largest *coordinate*; this also implies that the
48 /// smallest *coordinate* to start the loop is 0.
49 ///
50 /// For a sparse level, [posLo, loopHi) specifies the range of index pointers
51 /// (positions) used to load coordinates from the coordinate buffer.
  ///
  /// `inPadZone` is optional and defaults to a null Value. It presumably marks
  /// that the parent is inside a padding zone (cf. makePaddedIterator()); this
  /// needs confirming.
52 virtual std::pair<Value, Value>
54 ValueRange parentPos, Value inPadZone = nullptr) const = 0;
55
  /// Optional hook. The default implementation aborts via llvm_unreachable,
  /// so any level type on which it is called must override it.
56 virtual std::pair<Value, Value>
58 std::pair<Value, Value> parentRange) const {
59 llvm_unreachable("Not Implemented");
60 };
61
  // Accessors for the level index, the level type and the level size.
62 Level getLevel() const { return lvl; }
63 LevelType getLT() const { return lt; }
64 Value getSize() const { return lvlSize; }
  // Returns the buffers backing this level. SparseIterationSpace::toValues()
  // serializes them, followed by the level size, for every level.
65 virtual ValueRange getLvlBuffers() const = 0;
66
67 //
68 // Level properties
69 //
70 bool isUnique() const { return isUniqueLT(lt); }
71
72protected:
74 : tid(tid), lvl(lvl), lt(lt), lvlSize(lvlSize) {};
75
76public:
  // Tensor id and level index identifying this level (printed by toString()).
77 const unsigned tid, lvl;
80};
81
90
91/// A `SparseIterationSpace` represents a sparse set of coordinates defined by
92/// (possibly multiple) levels of a specific sparse tensor.
93/// TODO: remove `SparseTensorLevel` and switch to SparseIterationSpace when
94/// feature complete.
//
// NOTE(review): this extract is missing lines of this class:
//   - the `class SparseIterationSpace {` header;
//   - the special members listed in the doxygen index below (copy
//     constructor deleted, move constructor defaulted);
//   - the declarator of the 1-D constructor;
//   - the two lines opening the `fromValues` inverse, which declare `vals`;
//   - getLvlRef();
//   - the declaration of the `lvls` member.
// Restore them from the original header.
96public:
100
101 // Constructs an N-D iteration space over the levels in the half-open range
// [lvlRange.first, lvlRange.second) of tensor `t` (tensor id `tid`).
// `parentPos` is presumably the position of the enclosing parent level
// (cf. SparseTensorLevel::peekRangeAt()); confirm this.
102 SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid,
103 std::pair<Level, Level> lvlRange, ValueRange parentPos);
104
105 // Constructs a 1-D iteration space over the single level `lvl` by delegating
// to the N-D constructor with the level range {lvl, lvl + 1}.
107 Level lvl, ValueRange parentPos)
108 : SparseIterationSpace(loc, b, t, tid, {lvl, lvl + 1}, parentPos) {};
109
  // Whether the space is unique; this is decided by its last level.
  // NOTE(review): lvls.back() requires the space to have at least one level.
110 bool isUnique() const { return lvls.back()->isUnique(); }
111
  // Returns the number of levels spanned by this space.
112 unsigned getSpaceDim() const { return lvls.size(); }
113
114 // Reconstructs an iteration space directly from the provided ValueRange.
115 static SparseIterationSpace fromValues(IterSpaceType dstTp, ValueRange values,
116 unsigned tid);
117
118 // The inverse operation of `fromValues`. For every level it appends the
// level buffers followed by the level size, then appends the lower and
// upper bounds of the space.
121 for (auto &stl : lvls) {
122 llvm::append_range(vals, stl->getLvlBuffers());
123 vals.push_back(stl->getSize());
124 }
125 vals.append({bound.first, bound.second});
126 return vals;
127 }
128
  // Returns the last level of the space.
129 const SparseTensorLevel &getLastLvl() const { return *lvls.back(); }
133
  // The lower and upper bound values of the space (stored in `bound`).
134 Value getBoundLo() const { return bound.first; }
135 Value getBoundHi() const { return bound.second; }
136
137 // Extracts an iterator to iterate over the sparse iteration space.
138 std::unique_ptr<SparseIterator> extractIterator(OpBuilder &b,
139 Location l) const;
140
141private:
143 std::pair<Value, Value> bound;
144};
145
146/// Helper class that generates loop conditions, etc., to traverse a
147/// sparse tensor level.
//
// NOTE(review): this extract is missing many lines of this class:
//   - the declarators of setSparseEmitStrategy(), getSparseEmitStrategy(),
//     getCursorValTypes(), getCursor(), serialize(), forward(), genNotEnd(),
//     deref(), genNotEndImpl(), derefImpl(), forwardImpl(), linkNewScope()
//     and getMutCursorVals();
//   - the statement in genWhileCond() that defines `rem`;
//   - the declarations of the `emitStrategy` and `batchCrds` members.
// The doxygen index at the end of this page lists most of the missing
// signatures; restore them from the original header.
148class SparseIterator {
149 SparseIterator(SparseIterator &&) = delete;
150 SparseIterator(const SparseIterator &) = delete;
151 SparseIterator &operator=(SparseIterator &&) = delete;
152 SparseIterator &operator=(const SparseIterator &) = delete;
153
154protected:
  // Creates an iterator over level `lvl` of tensor `tid`. Its cursor occupies
  // the first `cursorValsCnt` values of `cursorValStorage`. The coordinate
  // starts out as a null Value and the batch coordinates start out empty.
155 SparseIterator(IterKind kind, unsigned tid, unsigned lvl,
156 unsigned cursorValsCnt,
157 SmallVectorImpl<Value> &cursorValStorage)
158 : batchCrds(0), kind(kind), tid(tid), lvl(lvl), crd(nullptr),
159 cursorValsCnt(cursorValsCnt), cursorValsStorageRef(cursorValStorage) {};
160
  // Creates an iterator with its own cursor storage. It reuses the tensor id
  // and the level of `delegate`.
161 SparseIterator(IterKind kind, unsigned cursorValsCnt,
162 SmallVectorImpl<Value> &cursorValStorage,
163 const SparseIterator &delegate)
164 : SparseIterator(kind, delegate.tid, delegate.lvl, cursorValsCnt,
165 cursorValStorage) {};
166
  // Creates an iterator that wraps `wrap`. Both share `wrap`'s cursor storage.
  // `extraCursorCnt` additional cursor values, initialized to null, are
  // appended after the values owned by `wrap`.
167 SparseIterator(IterKind kind, const SparseIterator &wrap,
168 unsigned extraCursorCnt = 0)
169 : SparseIterator(kind, wrap.tid, wrap.lvl,
170 extraCursorCnt + wrap.cursorValsCnt,
171 wrap.cursorValsStorageRef) {
172 assert(wrap.cursorValsCnt == wrap.cursorValsStorageRef.size());
173 cursorValsStorageRef.append(extraCursorCnt, nullptr);
174 assert(cursorValsStorageRef.size() == wrap.cursorValsCnt + extraCursorCnt);
175 };
176
177public:
178 virtual ~SparseIterator() = default;
179
181 emitStrategy = strategy;
182 }
183
185 return emitStrategy;
186 }
187
188 virtual std::string getDebugInterfacePrefix() const = 0;
190
  // Returns the cached coordinate. It is set through updateCrd() and reset to
  // a null Value by seek().
191 Value getCrd() const { return crd; }
194 return ValueRange(cursorValsStorageRef).take_front(cursorValsCnt);
195 };
196
197 // Sets the iterator to the specified position by overwriting the cursor
// values owned by this iterator with `vals`. `vals` must contain exactly
// cursorValsCnt values.
198 void seek(ValueRange vals) {
199 assert(vals.size() == cursorValsCnt);
200 std::copy(vals.begin(), vals.end(), cursorValsStorageRef.begin());
201 // Now that the iterator is re-positioned, the coordinate becomes invalid.
202 crd = nullptr;
203 }
204
205 // Reconstructs an iterator directly from the provided ValueRange.
206 static std::unique_ptr<SparseIterator>
207 fromValues(IteratorType dstTp, ValueRange values, unsigned tid);
208
209 // The inverse operation of `fromValues`. It is not implemented yet: calling
// it aborts via llvm_unreachable.
210 SmallVector<Value> toValues() const { llvm_unreachable("Not implemented"); }
211
212 //
213 // Iterator properties.
214 //
215
216 // Whether the iterator is an iterator over a batch level.
217 virtual bool isBatchIterator() const = 0;
218
219 // Whether the iterator supports random access (i.e., supports lookup by
220 // *coordinate*). A random-access iterator must also traverse a dense space.
221 virtual bool randomAccessible() const = 0;
222
223 // Whether the iterator can simply be traversed by a for loop (false by
// default).
224 virtual bool iteratableByFor() const { return false; };
225
226 // Gets the upper bound of the sparse space that the iterator might visit. A
227 // sparse space is a subset of a dense space [0, bound); this function returns
228 // *bound*.
229 virtual Value upperBound(OpBuilder &b, Location l) const = 0;
230
231 // Serializes and deserializes the current status to/from a set of values. The
232 // ValueRange should contain values that are sufficient to recover the current
233 // iterating position (i.e., the cursor values) as well as the loop bound.
234 //
235 // Not every type of iterator supports these operations. For example, the
236 // non-empty subsection iterator does not, because the number of non-empty
237 // subsections cannot be determined easily.
238 //
239 // NOTE: All the values should have index type. The default implementations
// below abort via llvm_unreachable.
241 llvm_unreachable("unsupported");
242 };
243 virtual void deserialize(ValueRange vs) { llvm_unreachable("unsupported"); };
244
245 //
246 // Core functions.
247 //
248
249 // Initializes the iterator according to the parent iterator's state.
250 void genInit(OpBuilder &b, Location l, const SparseIterator *p);
251
252 // Forwards the iterator to the next element.
254
255 // Locates the iterator at the position specified by *crd*; this can only
256 // be done on an iterator that supports random access.
257 void locate(OpBuilder &b, Location l, Value crd);
258
259 // Returns a boolean value that equals `!it.end()`.
261
262 // Dereferences the iterator, i.e., loads the coordinate at the current
// position.
263 //
264 // The method assumes that the iterator is not currently exhausted (i.e.,
265 // it != it.end()).
267
268 // Actual implementations provided by derived classes. locateImpl() is
// optional: its default implementation aborts via llvm_unreachable.
269 virtual void genInitImpl(OpBuilder &, Location, const SparseIterator *) = 0;
271 virtual void locateImpl(OpBuilder &b, Location l, Value crd) {
272 llvm_unreachable("Unsupported");
273 }
276 // Gets the ValueRange that together specifies the current position of the
277 // iterator. For a unique level, the position can be a single index pointing
278 // to the current coordinate being visited. For a non-unique level, an extra
279 // index for the `segment high` is needed to specify the range of
280 // duplicated coordinates. The ValueRange should be able to uniquely identify
281 // the sparse range for the next level. See SparseTensorLevel::peekRangeAt().
282 //
283 // Not every type of iterator supports this operation. For example, the
284 // non-empty subsection iterator does not, because it represents a range of
285 // coordinates instead of just one. Defaults to the iterator's cursor values.
286 virtual ValueRange getCurPosition() const { return getCursor(); };
287
288 // Returns a pair of values for the *lower* and *upper* bound respectively,
// i.e., the loop range [curCrd, UB). The default implementation only
// applies to random-access iterators.
289 virtual std::pair<Value, Value> genForCond(OpBuilder &b, Location l) {
290 assert(randomAccessible());
291 // Random-access iterator is traversed by coordinate, i.e., [curCrd, UB).
292 return {getCrd(), upperBound(b, l)};
293 }
294
295 // Generates a bool value for scf::ConditionOp. Returns the not-at-end
// condition (genNotEnd()) paired with `rem`. `rem` is presumably the part of
// `vs` left over after re-binding the cursor; the line defining `rem` is
// missing from this extract.
296 std::pair<Value, ValueRange> genWhileCond(OpBuilder &b, Location l,
297 ValueRange vs) {
299 return std::make_pair(genNotEnd(b, l), rem);
300 }
301
302 // Generates a conditional it.next() in the following form
303 //
304 // if (cond)
305 // yield it.next
306 // else
307 // yield it
308 //
309 // The function is virtual to allow alternative implementations. For example,
310 // if it.next() is trivial to compute, we can use a select operation instead.
311 // E.g.,
312 //
313 // it = select cond ? it+1 : it
314 virtual ValueRange forwardIf(OpBuilder &b, Location l, Value cond);
315
316 // Updates the SSA values for the iterator after entering a new scope. The
// leading cursorValsCnt values of `pos` become the new cursor (via seek()),
// and the remaining values are returned. Must not be used on random-access
// iterators; use locate() on those instead.
318 assert(!randomAccessible() && "random accessible iterators are traversed "
319 "by coordinate, call locate() instead.");
320 seek(pos.take_front(cursorValsCnt));
321 return pos.drop_front(cursorValsCnt);
322 };
323
324protected:
  // Caches `crd` as the current coordinate (returned by getCrd()).
325 void updateCrd(Value crd) { this->crd = crd; }
326
328 MutableArrayRef<Value> ref = cursorValsStorageRef;
329 return ref.take_front(cursorValsCnt);
330 }
331
  // Copies the batch coordinates of `parent` into this iterator.
  // NOTE(review): "inherent" looks like a typo for "inherit". Renaming it
  // would require updating every caller.
332 void inherentBatch(const SparseIterator &parent) {
333 batchCrds = parent.batchCrds;
334 }
335
338
339public:
340 const IterKind kind; // For LLVM-style RTTI.
341 const unsigned tid, lvl; // Tensor id and level index being iterated.
342
343private:
344 Value crd; // The sparse coordinate used to co-iterate (see updateCrd()).
345
346 // A range of values that together define the current state of the
347 // iterator. Only loop-variant values should be included.
348 //
349 // For trivial iterators, it is the position; for dedup iterators, it consists
350 // of the position and the segment high; for non-empty subsection iterators,
351 // it is the metadata that specifies the subsection.
352 // Note that the wrapped iterator shares the same cursor storage with its
353 // wrapper, which means the wrapped iterator might only own a subset (a
354 // prefix) of all the values stored in cursorValsStorageRef.
  // cursorValsCnt is the number of leading values of cursorValsStorageRef
  // owned by this iterator.
355 const unsigned cursorValsCnt;
356 SmallVectorImpl<Value> &cursorValsStorageRef;
357};
358
359/// Helper function to create a SparseTensorLevel object for level `lvl` of
/// the tensor `t`, identified by tensor id `tid`.
360std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
361 Location l, Value t,
362 unsigned tid,
363 Level lvl);
364
365/// Helper function to create a SparseTensorLevel object from an explicit level
/// type `lt`, a level size `sz` and the level's `buffers` (cf. the per-level
/// values serialized by SparseIterationSpace::toValues()).
366std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
367 ValueRange buffers,
368 unsigned tid, Level l);
369
370/// Helper function to create a simple SparseIterator object that iterates
371/// over the entire iteration space.
// NOTE(review): the parameter line before `iterSpace` is missing from this
// extract. Per the doxygen index below, the full signature is
// makeSimpleIterator(OpBuilder &b, Location l,
//                    const SparseIterationSpace &iterSpace).
372std::unique_ptr<SparseIterator>
374 const SparseIterationSpace &iterSpace);
375
376/// Helper function to create a simple SparseIterator object that iterates
377/// over the sparse tensor level.
378/// TODO: switch to `SparseIterationSpace` (which supports N-D iterators) when
379/// feature complete.
// NOTE(review): the remaining parameter line(s) after `stl` are missing from
// this extract.
380std::unique_ptr<SparseIterator> makeSimpleIterator(
381 const SparseTensorLevel &stl,
383
384/// Helper function to create a synthetic SparseIterator object that iterates
385/// over a dense space specified by [0,`sz`).
/// Returns both the synthetic SparseTensorLevel and the SparseIterator.
386std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
387makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
388 SparseEmitStrategy strategy);
389
390/// Helper function to create a SparseIterator object that iterates over a
391/// sliced space; the original space (before slicing) is traversed by `sit`.
/// `offset`, `stride` and `size` presumably describe the slice in terms of the
/// original space's coordinates; confirm against the implementation.
392std::unique_ptr<SparseIterator>
393makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset,
394 Value stride, Value size, SparseEmitStrategy strategy);
395
396/// Helper function to create a SparseIterator object that iterates over a
397/// padded sparse level (the padded value must be zero).
/// `padLow` and `padHigh` presumably give the amount of padding before and
/// after the level's original coordinates; confirm against the implementation.
398std::unique_ptr<SparseIterator>
399makePaddedIterator(std::unique_ptr<SparseIterator> &&sit, Value padLow,
400 Value padHigh, SparseEmitStrategy strategy);
401
402/// Helper function to create a SparseIterator object that iterates over the
403/// set of non-empty subsections.
404std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
405 OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
406 std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
407 SparseEmitStrategy strategy);
408
409/// Helper function to create a SparseIterator object that iterates over a
410/// non-empty subsection created by NonEmptySubSectIterator.
/// `subsectIter` is presumably the iterator returned by
/// makeNonEmptySubSectIterator(). `wrap` is the iterator being wrapped; it is
/// passed as an rvalue so that ownership can be moved into the result.
/// Confirm both against the implementation.
411std::unique_ptr<SparseIterator> makeTraverseSubSectIterator(
412 OpBuilder &b, Location l, const SparseIterator &subsectIter,
413 const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
414 Value loopBound, unsigned stride, SparseEmitStrategy strategy);
415
416} // namespace sparse_tensor
417} // namespace mlir
418
419#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
#define rem(a, b)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
A SparseIterationSpace represents a sparse set of coordinates defined by (possibly multiple) levels o...
SparseIterationSpace(SparseIterationSpace &)=delete
static SparseIterationSpace fromValues(IterSpaceType dstTp, ValueRange values, unsigned tid)
SparseIterationSpace(SparseIterationSpace &&)=default
ArrayRef< std::unique_ptr< SparseTensorLevel > > getLvlRef() const
std::unique_ptr< SparseIterator > extractIterator(OpBuilder &b, Location l) const
SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid, Level lvl, ValueRange parentPos)
const SparseTensorLevel & getLastLvl() const
Helper class that generates loop conditions, etc, to traverse a sparse tensor level.
virtual void genInitImpl(OpBuilder &, Location, const SparseIterator *)=0
ValueRange forward(OpBuilder &b, Location l)
SparseIterator(IterKind kind, unsigned cursorValsCnt, SmallVectorImpl< Value > &cursorValStorage, const SparseIterator &delegate)
virtual bool isBatchIterator() const =0
virtual void locateImpl(OpBuilder &b, Location l, Value crd)
void genInit(OpBuilder &b, Location l, const SparseIterator *p)
virtual Value upperBound(OpBuilder &b, Location l) const =0
virtual Value derefImpl(OpBuilder &b, Location l)=0
virtual void setSparseEmitStrategy(SparseEmitStrategy strategy)
Value genNotEnd(OpBuilder &b, Location l)
MutableArrayRef< Value > getMutCursorVals()
void locate(OpBuilder &b, Location l, Value crd)
std::pair< Value, ValueRange > genWhileCond(OpBuilder &b, Location l, ValueRange vs)
virtual void deserialize(ValueRange vs)
SparseIterator(IterKind kind, const SparseIterator &wrap, unsigned extraCursorCnt=0)
virtual SmallVector< Value > serialize() const
virtual std::pair< Value, Value > genForCond(OpBuilder &b, Location l)
virtual ValueRange forwardIf(OpBuilder &b, Location l, Value cond)
virtual SparseEmitStrategy getSparseEmitStrategy() const
SmallVector< Value > toValues() const
virtual Value genNotEndImpl(OpBuilder &b, Location l)=0
void inherentBatch(const SparseIterator &parent)
ValueRange linkNewScope(ValueRange pos)
virtual std::string getDebugInterfacePrefix() const =0
static std::unique_ptr< SparseIterator > fromValues(IteratorType dstTp, ValueRange values, unsigned tid)
virtual ValueRange getCurPosition() const
Value deref(OpBuilder &b, Location l)
virtual bool randomAccessible() const =0
virtual ValueRange forwardImpl(OpBuilder &b, Location l)=0
virtual SmallVector< Type > getCursorValTypes(OpBuilder &b) const =0
SparseIterator(IterKind kind, unsigned tid, unsigned lvl, unsigned cursorValsCnt, SmallVectorImpl< Value > &cursorValStorage)
The base class for all types of sparse tensor levels.
virtual std::pair< Value, Value > collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix, std::pair< Value, Value > parentRange) const
virtual std::pair< Value, Value > peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix, ValueRange parentPos, Value inPadZone=nullptr) const =0
Peeks the lower and upper bound to fully traverse the level with the given position parentPos,...
virtual Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix, Value iv) const =0
virtual ValueRange getLvlBuffers() const =0
SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize)
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition Diagnostics.h:24
bool isUniqueLT(LevelType lt)
Definition Enums.h:428
std::string toMLIRString(LevelType lt)
Definition Enums.h:447
std::unique_ptr< SparseTensorLevel > makeSparseTensorLevel(OpBuilder &b, Location l, Value t, unsigned tid, Level lvl)
Helper function to create a TensorLevel object from given tensor.
std::unique_ptr< SparseIterator > makeTraverseSubSectIterator(OpBuilder &b, Location l, const SparseIterator &subsectIter, const SparseIterator &parent, std::unique_ptr< SparseIterator > &&wrap, Value loopBound, unsigned stride, SparseEmitStrategy strategy)
Helper function to create a SparseIterator object that iterates over a non-empty subsection created b...
uint64_t Level
The type of level identifiers and level-ranks.
std::pair< std::unique_ptr< SparseTensorLevel >, std::unique_ptr< SparseIterator > > makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl, SparseEmitStrategy strategy)
Helper function to create a synthetic SparseIterator object that iterates over a dense space specifie...
std::unique_ptr< SparseIterator > makePaddedIterator(std::unique_ptr< SparseIterator > &&sit, Value padLow, Value padHigh, SparseEmitStrategy strategy)
Helper function to create a SparseIterator object that iterates over a padded sparse level (the padde...
std::unique_ptr< SparseIterator > makeSimpleIterator(OpBuilder &b, Location l, const SparseIterationSpace &iterSpace)
Helper function to create a simple SparseIterator object that iterate over the entire iteration space...
std::unique_ptr< SparseIterator > makeSlicedLevelIterator(std::unique_ptr< SparseIterator > &&sit, Value offset, Value stride, Value size, SparseEmitStrategy strategy)
Helper function to create a SparseIterator object that iterates over a sliced space,...
std::unique_ptr< SparseIterator > makeNonEmptySubSectIterator(OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound, std::unique_ptr< SparseIterator > &&delegate, Value size, unsigned stride, SparseEmitStrategy strategy)
Helper function to create a SparseIterator object that iterate over the non-empty subsections set.
Include the generated interface declarations.
SparseEmitStrategy
Defines a scope for reinterpret map pass.
Definition Passes.h:52
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition Enums.h:238