9 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
10 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
16 namespace sparse_tensor {
35 std::to_string(
lvl) +
"]";
52 virtual std::pair<Value, Value>
56 virtual std::pair<Value, Value>
58 std::pair<Value, Value> parentRange)
const {
59 llvm_unreachable(
"Not Implemented");
101 std::pair<Level, Level> lvlRange,
ValueRange parentPos);
108 bool isUnique()
const {
return lvls.back()->isUnique(); }
119 for (
auto &stl : lvls) {
120 llvm::append_range(vals, stl->getLvlBuffers());
121 vals.push_back(stl->getSize());
123 vals.append({bound.first, bound.second});
141 std::pair<Value, Value> bound;
154 unsigned cursorValsCnt,
157 cursorValsCnt(cursorValsCnt), cursorValsStorageRef(cursorValStorage) {};
163 cursorValStorage) {};
166 unsigned extraCursorCnt = 0)
168 extraCursorCnt +
wrap.cursorValsCnt,
169 wrap.cursorValsStorageRef) {
170 assert(
wrap.cursorValsCnt ==
wrap.cursorValsStorageRef.size());
171 cursorValsStorageRef.append(extraCursorCnt,
nullptr);
172 assert(cursorValsStorageRef.size() ==
wrap.cursorValsCnt + extraCursorCnt);
188 return ValueRange(cursorValsStorageRef).take_front(cursorValsCnt);
193 assert(vals.size() == cursorValsCnt);
194 std::copy(vals.begin(), vals.end(), cursorValsStorageRef.begin());
200 static std::unique_ptr<SparseIterator>
235 llvm_unreachable(
"unsupported");
266 llvm_unreachable(
"Unsupported");
293 return std::make_pair(
genNotEnd(b, l), rem);
313 "by coordinate, call locate() instead.");
314 seek(pos.take_front(cursorValsCnt));
315 return pos.drop_front(cursorValsCnt);
323 return ref.take_front(cursorValsCnt);
349 const unsigned cursorValsCnt;
362 unsigned tid,
Level l);
366 std::unique_ptr<SparseIterator>
380 std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
386 std::unique_ptr<SparseIterator>
392 std::unique_ptr<SparseIterator>
400 std::unique_ptr<SparseIterator> &&delegate,
Value size,
unsigned stride,
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A SparseIterationSpace represents a sparse set of coordinates defined by (possibly multiple) levels o...
SparseIterationSpace()=default
unsigned getSpaceDim() const
const SparseTensorLevel & getLastLvl() const
static SparseIterationSpace fromValues(IterSpaceType dstTp, ValueRange values, unsigned tid)
std::unique_ptr< SparseIterator > extractIterator(OpBuilder &b, Location l) const
SmallVector< Value > toValues() const
ArrayRef< std::unique_ptr< SparseTensorLevel > > getLvlRef() const
SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid, Level lvl, ValueRange parentPos)
Helper class that generates loop conditions, etc, to traverse a sparse tensor level.
virtual std::pair< Value, Value > genForCond(OpBuilder &b, Location l)
void updateCrd(Value crd)
MutableArrayRef< Value > getMutCursorVals()
virtual void genInitImpl(OpBuilder &, Location, const SparseIterator *)=0
ValueRange forward(OpBuilder &b, Location l)
SmallVector< Value > toValues() const
SparseIterator(IterKind kind, unsigned cursorValsCnt, SmallVectorImpl< Value > &cursorValStorage, const SparseIterator &delegate)
void setSparseEmitStrategy(SparseEmitStrategy strategy)
virtual bool isBatchIterator() const =0
virtual void locateImpl(OpBuilder &b, Location l, Value crd)
void genInit(OpBuilder &b, Location l, const SparseIterator *p)
SparseEmitStrategy emitStrategy
virtual Value upperBound(OpBuilder &b, Location l) const =0
ValueRange getBatchCrds() const
virtual ~SparseIterator()=default
virtual Value derefImpl(OpBuilder &b, Location l)=0
Value genNotEnd(OpBuilder &b, Location l)
void locate(OpBuilder &b, Location l, Value crd)
virtual void deserialize(ValueRange vs)
SmallVector< Value > batchCrds
SparseIterator(IterKind kind, const SparseIterator &wrap, unsigned extraCursorCnt=0)
static std::unique_ptr< SparseIterator > fromValues(IteratorType dstTp, ValueRange values, unsigned tid)
virtual ValueRange forwardIf(OpBuilder &b, Location l, Value cond)
virtual Value genNotEndImpl(OpBuilder &b, Location l)=0
void inherentBatch(const SparseIterator &parent)
ValueRange linkNewScope(ValueRange pos)
virtual std::string getDebugInterfacePrefix() const =0
ValueRange getCursor() const
virtual bool iteratableByFor() const
virtual SmallVector< Value > serialize() const
virtual ValueRange getCurPosition() const
Value deref(OpBuilder &b, Location l)
virtual bool randomAccessible() const =0
virtual ValueRange forwardImpl(OpBuilder &b, Location l)=0
void seek(ValueRange vals)
virtual SmallVector< Type > getCursorValTypes(OpBuilder &b) const =0
SparseIterator(IterKind kind, unsigned tid, unsigned lvl, unsigned cursorValsCnt, SmallVectorImpl< Value > &cursorValStorage)
std::pair< Value, ValueRange > genWhileCond(OpBuilder &b, Location l, ValueRange vs)
The base class for all types of sparse tensor levels.
virtual std::pair< Value, Value > peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix, ValueRange parentPos, Value inPadZone=nullptr) const =0
Peeks the lower and upper bound to fully traverse the level with the given position parentPos,...
virtual std::pair< Value, Value > collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix, std::pair< Value, Value > parentRange) const
virtual Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix, Value iv) const =0
std::string toString() const
virtual ~SparseTensorLevel()=default
virtual ValueRange getLvlBuffers() const =0
SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize)
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
bool isUniqueLT(LevelType lt)
std::string toMLIRString(LevelType lt)
std::unique_ptr< SparseTensorLevel > makeSparseTensorLevel(OpBuilder &b, Location l, Value t, unsigned tid, Level lvl)
Helper function to create a TensorLevel object from given tensor.
std::unique_ptr< SparseIterator > makeTraverseSubSectIterator(OpBuilder &b, Location l, const SparseIterator &subsectIter, const SparseIterator &parent, std::unique_ptr< SparseIterator > &&wrap, Value loopBound, unsigned stride, SparseEmitStrategy strategy)
Helper function to create a SparseIterator object that iterates over a non-empty subsection created b...
uint64_t Level
The type of level identifiers and level-ranks.
std::pair< std::unique_ptr< SparseTensorLevel >, std::unique_ptr< SparseIterator > > makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl, SparseEmitStrategy strategy)
Helper function to create a synthetic SparseIterator object that iterates over a dense space specifie...
std::unique_ptr< SparseIterator > makePaddedIterator(std::unique_ptr< SparseIterator > &&sit, Value padLow, Value padHigh, SparseEmitStrategy strategy)
Helper function to create a SparseIterator object that iterates over a padded sparse level (the padde...
std::unique_ptr< SparseIterator > makeSimpleIterator(OpBuilder &b, Location l, const SparseIterationSpace &iterSpace)
Helper function to create a simple SparseIterator object that iterate over the entire iteration space...
std::unique_ptr< SparseIterator > makeSlicedLevelIterator(std::unique_ptr< SparseIterator > &&sit, Value offset, Value stride, Value size, SparseEmitStrategy strategy)
Helper function to create a SparseIterator object that iterates over a sliced space,...
std::unique_ptr< SparseIterator > makeNonEmptySubSectIterator(OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound, std::unique_ptr< SparseIterator > &&delegate, Value size, unsigned stride, SparseEmitStrategy strategy)
Helper function to create a SparseIterator object that iterate over the non-empty subsections set.
Include the generated interface declarations.
SparseEmitStrategy
Defines a scope for reinterpret map pass.
This enum defines all the sparse representations supportable by the SparseTensor dialect.