MLIR 22.0.0git
SparseTensorDescriptor.h
Go to the documentation of this file.
1//===- SparseTensorDescriptor.h ---------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This header file defines utilities for the sparse memory layout.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_
14#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_
15
20
21namespace mlir {
22namespace sparse_tensor {
23
25public:
26 explicit SparseTensorSpecifier(Value specifier)
27 : specifier(cast<TypedValue<StorageSpecifierType>>(specifier)) {}
28
29 // Undef value for level-sizes, all zero values for memory-sizes.
30 static Value getInitValue(OpBuilder &builder, Location loc,
32
33 /*implicit*/ operator Value() { return specifier; }
34
36 StorageSpecifierKind kind, std::optional<Level> lvl);
37
38 void setSpecifierField(OpBuilder &builder, Location loc, Value v,
39 StorageSpecifierKind kind, std::optional<Level> lvl);
40
41private:
43};
44
45/// A helper class around an array of values that corresponds to a sparse
46/// tensor. This class provides a set of meaningful APIs to query and update
47/// a particular field in a consistent way. Users should not make assumptions
48/// on how a sparse tensor is laid out but instead rely on this class to access
49/// the right value for the right field.
50template <typename ValueArrayRef>
52protected:
54 : rType(stt), fields(fields), layout(stt) {
55 assert(layout.getNumFields() == getNumFields());
56 // We should make sure the class is trivially copyable (and should be small
57 // enough) such that we can pass it by value.
58 static_assert(std::is_trivially_copyable_v<
60 }
61
62public:
64 std::optional<Level> lvl) const {
65 // Delegates to storage layout.
66 return layout.getMemRefFieldIndex(kind, lvl);
67 }
68
69 unsigned getNumFields() const { return fields.size(); }
70
71 ///
72 /// Getters: get the value for required field.
73 ///
74
75 Value getSpecifier() const { return fields.back(); }
76
78 StorageSpecifierKind kind,
79 std::optional<Level> lvl) const {
81 return md.getSpecifierField(builder, loc, kind, lvl);
82 }
83
84 Value getLvlSize(OpBuilder &builder, Location loc, Level lvl) const {
85 return getSpecifierField(builder, loc, StorageSpecifierKind::LvlSize, lvl);
86 }
87
91
94 }
95
97 std::optional<Level> lvl) const {
98 return getField(getMemRefFieldIndex(kind, lvl));
99 }
100
102 assert(fidx < fields.size() - 1);
103 return getField(fidx);
104 }
105
106 Value getPosMemSize(OpBuilder &builder, Location loc, Level lvl) const {
107 return getSpecifierField(builder, loc, StorageSpecifierKind::PosMemSize,
108 lvl);
109 }
110
111 Value getCrdMemSize(OpBuilder &builder, Location loc, Level lvl) const {
112 return getSpecifierField(builder, loc, StorageSpecifierKind::CrdMemSize,
113 lvl);
114 }
115
116 Value getValMemSize(OpBuilder &builder, Location loc) const {
117 return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
118 std::nullopt);
119 }
120
122 std::optional<Level> lvl) const {
123 return getMemRefType(getMemRefField(kind, lvl)).getElementType();
124 }
125
127 assert(fidx < fields.size());
128 return fields[fidx];
129 }
130
132 return fields.drop_back(); // drop the last metadata fields
133 }
134
135 std::pair<FieldIndex, unsigned> getCrdMemRefIndexAndStride(Level lvl) const {
136 return layout.getFieldIndexAndStride(SparseTensorFieldKind::CrdMemRef, lvl);
137 }
138
140 const Level cooStart = rType.getAoSCOOStart();
141 assert(cooStart < rType.getLvlRank());
143 }
144
145 RankedTensorType getRankedTensorType() const { return rType; }
146 ValueArrayRef getFields() const { return fields; }
147 StorageLayout getLayout() const { return layout; }
148
149protected:
151 ValueArrayRef fields;
153};
154
155/// Uses ValueRange for immutable descriptors.
157public:
160
161 Value getCrdMemRefOrView(OpBuilder &builder, Location loc, Level lvl) const;
162};
163
164/// Using a SmallVector for the mutable descriptor lets users reuse it as a
165/// temporary buffer to append values in some special cases; users are then
166/// responsible for restoring the buffer to a legal state after use. This is
167/// not the cleanest approach, but it is the most efficient way to avoid
168/// copying the fields into another SmallVector. If a cleaner approach is
169/// wanted, this should be changed to use MutableArrayRef instead.
171 : public SparseTensorDescriptorImpl<SmallVectorImpl<Value> &> {
172public:
176
177 // Allow implicit type conversion from mutable descriptors to immutable ones
178 // (but not vice versa).
179 /*implicit*/ operator SparseTensorDescriptor() const {
181 }
182
183 ///
184 /// Additional setters for the mutable descriptor; each one updates the
185 /// value of the requested field.
186 ///
187
188 void setMemRefField(SparseTensorFieldKind kind, std::optional<Level> lvl,
189 Value v) {
190 fields[getMemRefFieldIndex(kind, lvl)] = v;
191 }
192
194 assert(fidx < fields.size() - 1);
195 fields[fidx] = v;
196 }
197
198 void setField(FieldIndex fidx, Value v) {
199 assert(fidx < fields.size());
200 fields[fidx] = v;
201 }
202
203 void setSpecifier(Value newSpec) { fields.back() = newSpec; }
204
206 StorageSpecifierKind kind, std::optional<Level> lvl,
207 Value v) {
208 SparseTensorSpecifier md(fields.back());
209 md.setSpecifierField(builder, loc, v, kind, lvl);
210 fields.back() = md;
211 }
212
213 void setValMemSize(OpBuilder &builder, Location loc, Value v) {
214 setSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
215 std::nullopt, v);
216 }
217
218 void setCrdMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
219 setSpecifierField(builder, loc, StorageSpecifierKind::CrdMemSize, lvl, v);
220 }
221
222 void setPosMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
223 setSpecifierField(builder, loc, StorageSpecifierKind::PosMemSize, lvl, v);
224 }
225
226 void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
227 setSpecifierField(builder, loc, StorageSpecifierKind::LvlSize, lvl, v);
228 }
229};
230
231/// Packs the given values as a "tuple" value.
232inline Value genTuple(OpBuilder &builder, Location loc, Type tp,
233 ValueRange values) {
234 return UnrealizedConversionCastOp::create(builder, loc, TypeRange(tp), values)
235 .getResult(0);
236}
237
238inline Value genTuple(OpBuilder &builder, Location loc,
240 return genTuple(builder, loc, desc.getRankedTensorType(), desc.getFields());
241}
242
243inline SparseTensorDescriptor
244getDescriptorFromTensorTuple(ValueRange adaptorValues, RankedTensorType type) {
245 return SparseTensorDescriptor(SparseTensorType(type), adaptorValues);
246}
247
248inline MutSparseTensorDescriptor
251 RankedTensorType type) {
252 fields.assign(adaptorValues.begin(), adaptorValues.end());
253 return MutSparseTensorDescriptor(SparseTensorType(type), fields);
254}
255
256} // namespace sparse_tensor
257} // namespace mlir
258
259#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Using a SmallVector for the mutable descriptor lets users reuse it as a temporary buffer to append values...
void setMemRefField(SparseTensorFieldKind kind, std::optional< Level > lvl, Value v)
Additional setters for the mutable descriptor; each one updates the value of the requested field.
MutSparseTensorDescriptor(SparseTensorType stt, SmallVectorImpl< Value > &buffers)
void setSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl, Value v)
void setPosMemSize(OpBuilder &builder, Location loc, Level lvl, Value v)
void setValMemSize(OpBuilder &builder, Location loc, Value v)
void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value v)
void setCrdMemSize(OpBuilder &builder, Location loc, Level lvl, Value v)
Value getSpecifier() const
Getters: return the value of the requested field.
FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Value getValMemSize(OpBuilder &builder, Location loc) const
Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl) const
std::pair< FieldIndex, unsigned > getCrdMemRefIndexAndStride(Level lvl) const
Type getMemRefElementType(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Value getCrdMemSize(OpBuilder &builder, Location loc, Level lvl) const
Value getMemRefField(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Value getPosMemSize(OpBuilder &builder, Location loc, Level lvl) const
Value getLvlSize(OpBuilder &builder, Location loc, Level lvl) const
SparseTensorDescriptorImpl(SparseTensorType stt, ValueArrayRef fields)
Uses ValueRange for immutable descriptors.
Value getCrdMemRefOrView(OpBuilder &builder, Location loc, Level lvl) const
SparseTensorDescriptor(SparseTensorType stt, ValueRange buffers)
void setSpecifierField(OpBuilder &builder, Location loc, Value v, StorageSpecifierKind kind, std::optional< Level > lvl)
Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl)
static Value getInitValue(OpBuilder &builder, Location loc, SparseTensorType stt)
A wrapper around RankedTensorType, which has three goals:
Provides methods to access fields of a sparse tensor with the given encoding.
SparseTensorDescriptor getDescriptorFromTensorTuple(ValueRange adaptorValues, RankedTensorType type)
Value genTuple(OpBuilder &builder, Location loc, Type tp, ValueRange values)
Packs the given values as a "tuple" value.
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
SparseTensorFieldKind
The sparse tensor storage ...
MutSparseTensorDescriptor getMutDescriptorFromTensorTuple(ValueRange adaptorValues, SmallVectorImpl< Value > &fields, RankedTensorType type)
Include the generated interface declarations.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497