13 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_
14 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_
22 namespace sparse_tensor {
27 : specifier(cast<
TypedValue<StorageSpecifierType>>(specifier)) {}
33 operator Value() {
return specifier; }
36 StorageSpecifierKind kind, std::optional<Level> lvl);
39 StorageSpecifierKind kind, std::optional<Level> lvl);
50 template <
typename ValueArrayRef>
58 static_assert(std::is_trivially_copyable_v<
64 std::optional<Level> lvl)
const {
78 StorageSpecifierKind kind,
79 std::optional<Level> lvl)
const {
97 std::optional<Level> lvl)
const {
102 assert(fidx <
fields.size() - 1);
122 std::optional<Level> lvl)
const {
127 assert(fidx <
fields.size());
132 return fields.drop_back();
194 assert(fidx <
fields.size() - 1);
199 assert(fidx <
fields.size());
206 StorageSpecifierKind kind, std::optional<Level> lvl,
233 return llvm::cast<UnrealizedConversionCastOp>(tensor.
getDefiningOp());
239 return builder.
create<UnrealizedConversionCastOp>(loc,
TypeRange(tp), values)
248 inline SparseTensorDescriptor
254 inline MutSparseTensorDescriptor
256 RankedTensorType type) {
258 fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Using SmallVector for the mutable descriptor allows users to reuse it as a temporary buffer to append values fo...
void setMemRefField(SparseTensorFieldKind kind, std::optional< Level > lvl, Value v)
Additional setters for the mutable descriptor; each updates the value of the requested field.
void setMemRefField(FieldIndex fidx, Value v)
MutSparseTensorDescriptor(SparseTensorType stt, SmallVectorImpl< Value > &buffers)
void setSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl, Value v)
void setSpecifier(Value newSpec)
void setPosMemSize(OpBuilder &builder, Location loc, Level lvl, Value v)
void setValMemSize(OpBuilder &builder, Location loc, Value v)
void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value v)
void setField(FieldIndex fidx, Value v)
void setCrdMemSize(OpBuilder &builder, Location loc, Level lvl, Value v)
A helper class around an array of values that corresponds to a sparse tensor.
Value getSpecifier() const
Getters: get the value for required field.
RankedTensorType getRankedTensorType() const
ValueArrayRef getFields() const
FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind, std::optional< Level > lvl) const
std::pair< FieldIndex, unsigned > getCrdMemRefIndexAndStride(Level lvl) const
Value getMemRefField(FieldIndex fidx) const
Value getValMemSize(OpBuilder &builder, Location loc) const
Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl) const
StorageLayout getLayout() const
ValueRange getMemRefFields() const
Value getAOSMemRef() const
Type getMemRefElementType(SparseTensorFieldKind kind, std::optional< Level > lvl) const
unsigned getNumFields() const
Value getCrdMemSize(OpBuilder &builder, Location loc, Level lvl) const
Value getMemRefField(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Value getPosMemSize(OpBuilder &builder, Location loc, Level lvl) const
Value getValMemRef() const
Value getField(FieldIndex fidx) const
Value getLvlSize(OpBuilder &builder, Location loc, Level lvl) const
SparseTensorDescriptorImpl(SparseTensorType stt, ValueArrayRef fields)
Value getPosMemRef(Level lvl) const
Uses ValueRange for immutable descriptors.
Value getCrdMemRefOrView(OpBuilder &builder, Location loc, Level lvl) const
SparseTensorDescriptor(SparseTensorType stt, ValueRange buffers)
void setSpecifierField(OpBuilder &builder, Location loc, Value v, StorageSpecifierKind kind, std::optional< Level > lvl)
SparseTensorSpecifier(Value specifier)
Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl)
static Value getInitValue(OpBuilder &builder, Location loc, SparseTensorType stt)
A wrapper around RankedTensorType, which has three goals:
Level getLvlRank() const
Returns the level-rank.
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
Provides methods to access fields of a sparse tensor with the given encoding.
unsigned getNumFields() const
Gets the total number of fields for the given sparse tensor encoding.
FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Gets the field index for required field.
std::pair< FieldIndex, unsigned > getFieldIndexAndStride(SparseTensorFieldKind kind, std::optional< Level > lvl) const
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
Value genTuple(OpBuilder &builder, Location loc, Type tp, ValueRange values)
Packs the given values as a "tuple" value.
MutSparseTensorDescriptor getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl< Value > &fields, RankedTensorType type)
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor, RankedTensorType type)
SparseTensorFieldKind
The sparse tensor storag...
UnrealizedConversionCastOp getTuple(Value tensor)
Returns the "tuple" value of the adapted tensor.
Include the generated interface declarations.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.