MLIR  19.0.0git
SparseTensorDescriptor.h
Go to the documentation of this file.
1 //===- SparseTensorDescriptor.h ---------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities for the sparse memory layout.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_
14 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_
15 
20 
21 namespace mlir {
22 namespace sparse_tensor {
23 
25 public:
26  explicit SparseTensorSpecifier(Value specifier)
27  : specifier(cast<TypedValue<StorageSpecifierType>>(specifier)) {}
28 
29  // Undef value for level-sizes, all zero values for memory-sizes.
30  static Value getInitValue(OpBuilder &builder, Location loc,
31  SparseTensorType stt);
32 
33  /*implicit*/ operator Value() { return specifier; }
34 
36  StorageSpecifierKind kind, std::optional<Level> lvl);
37 
38  void setSpecifierField(OpBuilder &builder, Location loc, Value v,
39  StorageSpecifierKind kind, std::optional<Level> lvl);
40 
41 private:
43 };
44 
45 /// A helper class around an array of values that corresponds to a sparse
46 /// tensor. This class provides a set of meaningful APIs to query and update
47 /// a particular field in a consistent way. Users should not make assumptions
48 /// about how a sparse tensor is laid out but instead rely on this class to access
49 /// the right value for the right field.
50 template <typename ValueArrayRef>
52 protected:
54  : rType(stt), fields(fields), layout(stt) {
55  assert(layout.getNumFields() == getNumFields());
56  // We should make sure the class is trivially copyable (and should be small
57  // enough) such that we can pass it by value.
58  static_assert(std::is_trivially_copyable_v<
60  }
61 
62 public:
64  std::optional<Level> lvl) const {
65  // Delegates to storage layout.
66  return layout.getMemRefFieldIndex(kind, lvl);
67  }
68 
69  unsigned getNumFields() const { return fields.size(); }
70 
71  ///
72  /// Getters: get the value for required field.
73  ///
74 
75  Value getSpecifier() const { return fields.back(); }
76 
78  StorageSpecifierKind kind,
79  std::optional<Level> lvl) const {
80  SparseTensorSpecifier md(fields.back());
81  return md.getSpecifierField(builder, loc, kind, lvl);
82  }
83 
84  Value getLvlSize(OpBuilder &builder, Location loc, Level lvl) const {
85  return getSpecifierField(builder, loc, StorageSpecifierKind::LvlSize, lvl);
86  }
87 
88  Value getPosMemRef(Level lvl) const {
90  }
91 
92  Value getValMemRef() const {
94  }
95 
97  std::optional<Level> lvl) const {
98  return getField(getMemRefFieldIndex(kind, lvl));
99  }
100 
102  assert(fidx < fields.size() - 1);
103  return getField(fidx);
104  }
105 
106  Value getPosMemSize(OpBuilder &builder, Location loc, Level lvl) const {
107  return getSpecifierField(builder, loc, StorageSpecifierKind::PosMemSize,
108  lvl);
109  }
110 
111  Value getCrdMemSize(OpBuilder &builder, Location loc, Level lvl) const {
112  return getSpecifierField(builder, loc, StorageSpecifierKind::CrdMemSize,
113  lvl);
114  }
115 
116  Value getValMemSize(OpBuilder &builder, Location loc) const {
117  return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
118  std::nullopt);
119  }
120 
122  std::optional<Level> lvl) const {
123  return getMemRefType(getMemRefField(kind, lvl)).getElementType();
124  }
125 
126  Value getField(FieldIndex fidx) const {
127  assert(fidx < fields.size());
128  return fields[fidx];
129  }
130 
132  return fields.drop_back(); // drop the last metadata fields
133  }
134 
135  std::pair<FieldIndex, unsigned> getCrdMemRefIndexAndStride(Level lvl) const {
137  }
138 
139  Value getAOSMemRef() const {
140  const Level cooStart = rType.getAoSCOOStart();
141  assert(cooStart < rType.getLvlRank());
143  }
144 
145  RankedTensorType getRankedTensorType() const { return rType; }
146  ValueArrayRef getFields() const { return fields; }
147  StorageLayout getLayout() const { return layout; }
148 
149 protected:
151  ValueArrayRef fields;
153 };
154 
155 /// Uses ValueRange for immutable descriptors.
157 public:
159  : SparseTensorDescriptorImpl<ValueRange>(stt, buffers) {}
160 
161  Value getCrdMemRefOrView(OpBuilder &builder, Location loc, Level lvl) const;
162 };
163 
164 /// Using a SmallVector for the mutable descriptor allows users to reuse it as a
165 /// temporary buffer to append values in some special cases, though users are
166 /// responsible for restoring the buffer to a legal state after their use. It
167 /// is probably not a clean way, but it is the most efficient way to avoid
168 /// copying the fields into another SmallVector. If a cleaner way is
169 /// wanted, we should change it to MutableArrayRef instead.
171  : public SparseTensorDescriptorImpl<SmallVectorImpl<Value> &> {
172 public:
174  SmallVectorImpl<Value> &buffers)
175  : SparseTensorDescriptorImpl<SmallVectorImpl<Value> &>(stt, buffers) {}
176 
177  // Allow implicit type conversion from mutable descriptors to immutable ones
178  // (but not vice versa).
179  /*implicit*/ operator SparseTensorDescriptor() const {
181  }
182 
183  ///
184  /// Additional setters for the mutable descriptor; each updates the value
185  /// of the requested field.
186  ///
187 
188  void setMemRefField(SparseTensorFieldKind kind, std::optional<Level> lvl,
189  Value v) {
190  fields[getMemRefFieldIndex(kind, lvl)] = v;
191  }
192 
194  assert(fidx < fields.size() - 1);
195  fields[fidx] = v;
196  }
197 
198  void setField(FieldIndex fidx, Value v) {
199  assert(fidx < fields.size());
200  fields[fidx] = v;
201  }
202 
203  void setSpecifier(Value newSpec) { fields.back() = newSpec; }
204 
206  StorageSpecifierKind kind, std::optional<Level> lvl,
207  Value v) {
208  SparseTensorSpecifier md(fields.back());
209  md.setSpecifierField(builder, loc, v, kind, lvl);
210  fields.back() = md;
211  }
212 
213  void setValMemSize(OpBuilder &builder, Location loc, Value v) {
214  setSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
215  std::nullopt, v);
216  }
217 
218  void setCrdMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
219  setSpecifierField(builder, loc, StorageSpecifierKind::CrdMemSize, lvl, v);
220  }
221 
222  void setPosMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
223  setSpecifierField(builder, loc, StorageSpecifierKind::PosMemSize, lvl, v);
224  }
225 
226  void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
227  setSpecifierField(builder, loc, StorageSpecifierKind::LvlSize, lvl, v);
228  }
229 };
230 
231 /// Returns the "tuple" value of the adapted tensor.
232 inline UnrealizedConversionCastOp getTuple(Value tensor) {
233  return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
234 }
235 
236 /// Packs the given values as a "tuple" value.
237 inline Value genTuple(OpBuilder &builder, Location loc, Type tp,
238  ValueRange values) {
239  return builder.create<UnrealizedConversionCastOp>(loc, TypeRange(tp), values)
240  .getResult(0);
241 }
242 
243 inline Value genTuple(OpBuilder &builder, Location loc,
244  SparseTensorDescriptor desc) {
245  return genTuple(builder, loc, desc.getRankedTensorType(), desc.getFields());
246 }
247 
249  auto tuple = getTuple(tensor);
250  SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
251  return SparseTensorDescriptor(stt, tuple.getInputs());
252 }
253 
254 inline MutSparseTensorDescriptor
256  auto tuple = getTuple(tensor);
257  fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
258  SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
259  return MutSparseTensorDescriptor(stt, fields);
260 }
261 
262 } // namespace sparse_tensor
263 } // namespace mlir
264 
265 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORDESCRIPTOR_H_
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Using SmallVector for mutable descriptor allows users to reuse it as a tmp buffers to append value fo...
void setMemRefField(SparseTensorFieldKind kind, std::optional< Level > lvl, Value v)
Adds additional setters for mutable descriptor, update the value for required field.
MutSparseTensorDescriptor(SparseTensorType stt, SmallVectorImpl< Value > &buffers)
void setSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl, Value v)
void setPosMemSize(OpBuilder &builder, Location loc, Level lvl, Value v)
void setValMemSize(OpBuilder &builder, Location loc, Value v)
void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value v)
void setCrdMemSize(OpBuilder &builder, Location loc, Level lvl, Value v)
A helper class around an array of values that corresponds to a sparse tensor.
Value getSpecifier() const
Getters: get the value for required field.
FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind, std::optional< Level > lvl) const
std::pair< FieldIndex, unsigned > getCrdMemRefIndexAndStride(Level lvl) const
Value getValMemSize(OpBuilder &builder, Location loc) const
Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl) const
Type getMemRefElementType(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Value getCrdMemSize(OpBuilder &builder, Location loc, Level lvl) const
Value getMemRefField(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Value getPosMemSize(OpBuilder &builder, Location loc, Level lvl) const
Value getLvlSize(OpBuilder &builder, Location loc, Level lvl) const
SparseTensorDescriptorImpl(SparseTensorType stt, ValueArrayRef fields)
Uses ValueRange for immutable descriptors.
Value getCrdMemRefOrView(OpBuilder &builder, Location loc, Level lvl) const
SparseTensorDescriptor(SparseTensorType stt, ValueRange buffers)
void setSpecifierField(OpBuilder &builder, Location loc, Value v, StorageSpecifierKind kind, std::optional< Level > lvl)
Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl)
static Value getInitValue(OpBuilder &builder, Location loc, SparseTensorType stt)
A wrapper around RankedTensorType, which has three goals:
Level getLvlRank() const
Returns the level-rank.
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
Provides methods to access fields of a sparse tensor with the given encoding.
unsigned getNumFields() const
Gets the total number of fields for the given sparse tensor encoding.
FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Gets the field index for required field.
std::pair< FieldIndex, unsigned > getFieldIndexAndStride(SparseTensorFieldKind kind, std::optional< Level > lvl) const
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:38
Value genTuple(OpBuilder &builder, Location loc, Type tp, ValueRange values)
Packs the given values as a "tuple" value.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Definition: SparseTensor.h:82
SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor)
SparseTensorFieldKind
===-------------------------------------------------------------------—===// The sparse tensor storag...
MutSparseTensorDescriptor getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl< Value > &fields)
UnrealizedConversionCastOp getTuple(Value tensor)
Returns the "tuple" value of the adapted tensor.
Include the generated interface declarations.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:494