MLIR 19.0.0git
mlir::sparse_tensor::StorageLayout Class Reference

Provides methods to access fields of a sparse tensor with the given encoding.

#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"

Public Member Functions

 StorageLayout (const SparseTensorType &stt)

 StorageLayout (SparseTensorEncodingAttr enc)

void foreachField (llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
 For each field that will be allocated for the given sparse tensor encoding, calls the callback with the corresponding field index, field kind, level, and level-type (the last two are only for level memrefs).

FieldIndex getMemRefFieldIndex (SparseTensorFieldKind kind, std::optional< Level > lvl) const
 Gets the field index for the required field.

unsigned getNumFields () const
 Gets the total number of fields for the given sparse tensor encoding.

unsigned getNumDataFields () const
 Gets the total number of data fields (coordinate arrays, position arrays, and a value array) for the given sparse tensor encoding.

std::pair< FieldIndex, unsigned > getFieldIndexAndStride (SparseTensorFieldKind kind, std::optional< Level > lvl) const

Detailed Description

Provides methods to access fields of a sparse tensor with the given encoding.

Definition at line 114 of file SparseTensorStorageLayout.h.

Constructor & Destructor Documentation

◆ StorageLayout() [1/2]

mlir::sparse_tensor::StorageLayout::StorageLayout ( const SparseTensorType & stt)
inline explicit

Definition at line 116 of file SparseTensorStorageLayout.h.

◆ StorageLayout() [2/2]

mlir::sparse_tensor::StorageLayout::StorageLayout ( SparseTensorEncodingAttr enc)
inline explicit

Definition at line 118 of file SparseTensorStorageLayout.h.

Member Function Documentation

◆ foreachField()

void StorageLayout::foreachField ( llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)> callback) const

For each field that will be allocated for the given sparse tensor encoding, calls the callback with the corresponding field index, field kind, level, and level-type (the last two are only for level memrefs).

The field index always starts at zero and increases by one with each callback invocation. Returning false from the callback stops the iteration. Ideally, all other methods should rely on this function to query the fields of a sparse tensor instead of relying on ad-hoc index computation.

Definition at line 100 of file SparseTensorDialect.cpp.

References mlir::sparse_tensor::CrdMemRef, mlir::sparse_tensor::SparseTensorType::getCOOSegments(), mlir::sparse_tensor::isWithCrdLT(), mlir::sparse_tensor::isWithPosLT(), kDataFieldStartingIdx, kInvalidLevel, mlir::sparse_tensor::PosMemRef, mlir::sparse_tensor::StorageSpec, mlir::sparse_tensor::Undef, and mlir::sparse_tensor::ValMemRef.

Referenced by mlir::sparse_tensor::foreachFieldAndTypeInSparseTensor(), mlir::sparse_tensor::foreachFieldInSparseTensor(), getNumDataFields(), and getNumFields().
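The enumeration order can be illustrated with a standalone C++ sketch that does not depend on MLIR. It is a simplified model, not the implementation: `FieldKind`, `LevelFormat`, and `foreachFieldModel` are hypothetical stand-ins, only dense, compressed, and singleton levels are modeled, and every coordinates array is assumed to belong to a single level (no fused array-of-structures COO buffer). The model also assumes that the per-level fields come first, followed by the value array and then the storage specifier.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical stand-ins for the MLIR types named on this page.
using FieldIndex = unsigned;
using Level = uint64_t;
enum class FieldKind { PosMemRef, CrdMemRef, ValMemRef, StorageSpec };
enum class LevelFormat { Dense, Compressed, Singleton };
constexpr Level kInvalidLevel = static_cast<Level>(-1);

// Simplified model of the field enumeration. Field indices start at zero and
// grow by one per callback; a callback returning false stops the walk.
void foreachFieldModel(
    const std::vector<LevelFormat> &lvlTypes,
    const std::function<bool(FieldIndex, FieldKind, Level)> &callback) {
  FieldIndex fieldIdx = 0;
  for (Level l = 0; l < lvlTypes.size(); ++l) {
    const LevelFormat lt = lvlTypes[l];
    // Compressed levels store a positions array.
    if (lt == LevelFormat::Compressed &&
        !callback(fieldIdx++, FieldKind::PosMemRef, l))
      return;
    // Compressed and singleton levels store a coordinates array.
    if ((lt == LevelFormat::Compressed || lt == LevelFormat::Singleton) &&
        !callback(fieldIdx++, FieldKind::CrdMemRef, l))
      return;
  }
  // Level-independent fields: the value array, then the storage specifier.
  if (!callback(fieldIdx++, FieldKind::ValMemRef, kInvalidLevel))
    return;
  callback(fieldIdx, FieldKind::StorageSpec, kInvalidLevel);
}
```

For a CSR-like level list `{Dense, Compressed}`, the model reports the level-1 positions array as field 0, the level-1 coordinates array as field 1, the value array as field 2, and the storage specifier as field 3.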

◆ getFieldIndexAndStride()

std::pair< FieldIndex, unsigned > StorageLayout::getFieldIndexAndStride ( SparseTensorFieldKind kind,
std::optional< Level > lvl
) const
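This method has no description on this page. Its name and return type suggest that it returns a field index together with a stride. The sketch below is one interpretation, stated as an assumption rather than documented behavior: the stride is 1 for ordinary fields, and when the coordinates of a trailing run of levels starting at `aosCooStart` are interleaved in one array-of-structures (AoS) COO buffer, a coordinates request for any level in that run resolves to the shared buffer with a stride equal to the number of interleaved levels. All names (`getFieldIndexAndStrideModel`, `fieldsModel`, `aosCooStart`, `LevelFormat`) are hypothetical, and level formats are reduced to dense, compressed, and singleton.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical stand-ins; not MLIR API.
using FieldIndex = unsigned;
using Level = uint64_t;
enum class FieldKind { PosMemRef, CrdMemRef, ValMemRef, StorageSpec };
enum class LevelFormat { Dense, Compressed, Singleton };

struct FieldInfo {
  FieldKind kind;
  std::optional<Level> lvl; // empty for level-independent fields
};

// Lists fields in order. Levels from aosCooStart to the end share a single
// coordinates buffer, emitted at aosCooStart (aosCooStart == rank: no AoS COO).
std::vector<FieldInfo> fieldsModel(const std::vector<LevelFormat> &lvlTypes,
                                   Level aosCooStart) {
  std::vector<FieldInfo> fields;
  for (Level l = 0; l < lvlTypes.size(); ++l) {
    if (lvlTypes[l] == LevelFormat::Compressed)
      fields.push_back({FieldKind::PosMemRef, l});
    if (lvlTypes[l] != LevelFormat::Dense) // compressed or singleton
      fields.push_back({FieldKind::CrdMemRef, l});
    if (l == aosCooStart)
      break; // remaining levels reuse the fused coordinates buffer
  }
  fields.push_back({FieldKind::ValMemRef, std::nullopt});
  fields.push_back({FieldKind::StorageSpec, std::nullopt});
  return fields;
}

// Returns {field index, stride} under the assumptions stated above.
std::pair<FieldIndex, unsigned>
getFieldIndexAndStrideModel(const std::vector<LevelFormat> &lvlTypes,
                            Level aosCooStart, FieldKind kind,
                            std::optional<Level> lvl) {
  const Level rank = lvlTypes.size();
  unsigned stride = 1;
  if (kind == FieldKind::CrdMemRef && lvl && *lvl >= aosCooStart &&
      *lvl < rank) {
    lvl = aosCooStart;                                  // use the fused buffer
    stride = static_cast<unsigned>(rank - aosCooStart); // one entry per level
  }
  const std::vector<FieldInfo> fields = fieldsModel(lvlTypes, aosCooStart);
  for (FieldIndex i = 0; i < fields.size(); ++i) {
    // The value array matches regardless of the level argument.
    const bool levelMatches =
        kind == FieldKind::ValMemRef || fields[i].lvl == lvl;
    if (fields[i].kind == kind && levelMatches)
      return {i, stride};
  }
  assert(false && "no such field in this layout");
  return {0, 0};
}
```

For a two-level `{Compressed, Singleton}` layout with `aosCooStart = 0`, a coordinates request for either level resolves to field 1 with stride 2, because in this model both levels' coordinates share one interleaved buffer.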

◆ getMemRefFieldIndex()

FieldIndex mlir::sparse_tensor::StorageLayout::getMemRefFieldIndex ( SparseTensorFieldKind kind,
std::optional< Level > lvl
) const
inline

Gets the field index of the required field, identified by its kind and, for level memrefs, its level.

Definition at line 136 of file SparseTensorStorageLayout.h.
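A lookup in the spirit of the advice under foreachField (derive indices from the enumeration rather than from ad-hoc arithmetic) walks the fields and stops at the first match. The sketch below is a hypothetical, MLIR-free model: `getMemRefFieldIndexModel` and `foreachFieldModel` are not MLIR functions, only dense, compressed, and singleton levels are modeled with one coordinates array per level, and matching the value array regardless of the level argument is an assumption of the model.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical stand-ins for the MLIR types named on this page.
using FieldIndex = unsigned;
using Level = uint64_t;
enum class FieldKind { PosMemRef, CrdMemRef, ValMemRef, StorageSpec };
enum class LevelFormat { Dense, Compressed, Singleton };
constexpr Level kInvalidLevel = static_cast<Level>(-1);

// Simplified enumeration: per-level positions/coordinates, then the value
// array and the storage specifier. Returning false stops the walk.
void foreachFieldModel(
    const std::vector<LevelFormat> &lvlTypes,
    const std::function<bool(FieldIndex, FieldKind, Level)> &callback) {
  FieldIndex idx = 0;
  for (Level l = 0; l < lvlTypes.size(); ++l) {
    const LevelFormat lt = lvlTypes[l];
    if (lt == LevelFormat::Compressed &&
        !callback(idx++, FieldKind::PosMemRef, l))
      return;
    // Compressed and singleton levels both carry coordinates.
    if (lt != LevelFormat::Dense && !callback(idx++, FieldKind::CrdMemRef, l))
      return;
  }
  if (!callback(idx++, FieldKind::ValMemRef, kInvalidLevel))
    return;
  callback(idx, FieldKind::StorageSpec, kInvalidLevel);
}

// Looks up a field by kind and, for per-level arrays, by level.
FieldIndex getMemRefFieldIndexModel(const std::vector<LevelFormat> &lvlTypes,
                                    FieldKind kind, std::optional<Level> lvl) {
  std::optional<FieldIndex> found;
  foreachFieldModel(lvlTypes, [&](FieldIndex fIdx, FieldKind fKind, Level fLvl) {
    const bool levelMatches = lvl.has_value() && fLvl == *lvl;
    if (fKind == kind && (levelMatches || kind == FieldKind::ValMemRef)) {
      found = fIdx;
      return false; // field found: stop enumerating
    }
    return true;
  });
  assert(found.has_value() && "no such field in this layout");
  return *found;
}
```

With a CSR-like `{Dense, Compressed}` layout, the model returns 0 for the level-1 positions array, 1 for the level-1 coordinates array, and 2 for the value array.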

◆ getNumDataFields()

unsigned StorageLayout::getNumDataFields ( ) const

Gets the total number of data fields (coordinate arrays, position arrays, and a value array) for the given sparse tensor encoding.

Definition at line 193 of file SparseTensorDialect.cpp.

References foreachField(), getNumFields(), and kDataFieldStartingIdx.

Referenced by mlir::sparse_tensor::getNumDataFieldsFromEncoding().
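Under a simplified model (hypothetical: dense, compressed, and singleton levels only, one coordinates array per coordinate-carrying level, and the storage specifier as the only non-data field), the data-field count has a closed form: one positions array per compressed level, one coordinates array per compressed or singleton level, and one value array. `numDataFieldsModel` and `numFieldsModel` below are illustrative names, not MLIR API.

```cpp
#include <cassert>
#include <vector>

// Hypothetical level formats; real MLIR level types are richer.
enum class LevelFormat { Dense, Compressed, Singleton };

// Closed-form count of data fields (positions, coordinates, and value arrays)
// under the simplified model; the storage specifier is not counted.
unsigned numDataFieldsModel(const std::vector<LevelFormat> &lvlTypes) {
  unsigned count = 1; // the single value array
  for (LevelFormat lt : lvlTypes) {
    if (lt == LevelFormat::Compressed)
      count += 2; // positions + coordinates
    else if (lt == LevelFormat::Singleton)
      count += 1; // coordinates only
    // Dense levels need no arrays of their own.
  }
  return count;
}

// Total fields = data fields + the trailing storage specifier.
unsigned numFieldsModel(const std::vector<LevelFormat> &lvlTypes) {
  return numDataFieldsModel(lvlTypes) + 1;
}
```

For a CSR-like `{Dense, Compressed}` layout this gives 3 data fields and 4 fields in total; an all-dense layout has only the value array as a data field.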

◆ getNumFields()

unsigned StorageLayout::getNumFields ( ) const

Gets the total number of fields for the given sparse tensor encoding.

Definition at line 183 of file SparseTensorDialect.cpp.

References foreachField().

Referenced by getNumDataFields(), mlir::sparse_tensor::getNumFieldsFromEncoding(), and mlir::sparse_tensor::SparseTensorDescriptorImpl< ValueArrayRef >::SparseTensorDescriptorImpl().
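Counting fields needs no index arithmetic of its own: a callback that increments a counter and always returns true visits every field exactly once. The sketch below shows this on a hypothetical, MLIR-free model of the enumeration (dense, compressed, and singleton levels only, one coordinates array per level, value array and storage specifier last); none of its names are MLIR API.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical stand-ins; not MLIR API.
using FieldIndex = unsigned;
using Level = uint64_t;
enum class FieldKind { PosMemRef, CrdMemRef, ValMemRef, StorageSpec };
enum class LevelFormat { Dense, Compressed, Singleton };
constexpr Level kInvalidLevel = static_cast<Level>(-1);

// Simplified enumeration (dense, compressed, and singleton levels only).
void foreachFieldModel(
    const std::vector<LevelFormat> &lvlTypes,
    const std::function<bool(FieldIndex, FieldKind, Level)> &callback) {
  FieldIndex idx = 0;
  for (Level l = 0; l < lvlTypes.size(); ++l) {
    if (lvlTypes[l] == LevelFormat::Compressed &&
        !callback(idx++, FieldKind::PosMemRef, l))
      return;
    if (lvlTypes[l] != LevelFormat::Dense &&
        !callback(idx++, FieldKind::CrdMemRef, l))
      return;
  }
  if (!callback(idx++, FieldKind::ValMemRef, kInvalidLevel))
    return;
  callback(idx, FieldKind::StorageSpec, kInvalidLevel);
}

// Counts fields by visiting all of them; the callback never stops the walk.
unsigned numFieldsModel(const std::vector<LevelFormat> &lvlTypes) {
  unsigned numFields = 0;
  foreachFieldModel(lvlTypes, [&numFields](FieldIndex, FieldKind, Level) {
    ++numFields;
    return true;
  });
  return numFields;
}
```

With these assumptions a CSR-like `{Dense, Compressed}` layout has 4 fields, and an all-dense layout has 2 (the value array and the storage specifier).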


The documentation for this class was generated from the following files:
SparseTensorStorageLayout.h
SparseTensorDialect.cpp