SparseTensorType.h
//===- SparseTensorType.h - Wrapper around RankedTensorType -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header defines the `SparseTensorType` wrapper class.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_
#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
/// A wrapper around `RankedTensorType`, which has three goals:
///
/// (1) To provide a uniform API for querying aspects of sparse-tensor
/// types; in particular, to make the "dimension" vs "level" distinction
/// overt (i.e., explicit everywhere). Thus, throughout the sparsifier
/// this class should be preferred over using `RankedTensorType` or
/// `ShapedType` directly, since the methods of the latter do not make
/// the "dimension" vs "level" distinction overt.
///
/// (2) To provide a uniform abstraction over both sparse-tensor
/// types (i.e., `RankedTensorType` with `SparseTensorEncodingAttr`)
/// and dense-tensor types (i.e., `RankedTensorType` without an encoding).
/// That is, we want to manipulate dense-tensor types using the same API
/// that we use for manipulating sparse-tensor types; both to keep the
/// "dimension" vs "level" distinction overt, and to avoid needing to
/// handle certain cases specially in the sparsifier.
///
/// (3) To provide uniform handling of "defaults". In particular,
/// this means that dense-tensors should always return the same answers
/// as sparse-tensors with a default encoding. But it additionally means
/// that the answers should be normalized, so that there's no way to
/// distinguish between non-provided data (which is filled in by default)
/// and explicitly-provided data that equals the defaults.
///
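/// For illustration, a hedged sketch of the "dimension" vs "level"
/// distinction (`tensorVal` is a hypothetical Value, assumed here to have
/// type `tensor<3x4xf64, #CSC>`, where the CSC-style encoding maps
/// dimensions (i, j) to levels (j, i)):
///
/// \code
///   SparseTensorType stt = getSparseTensorType(tensorVal);
///   ArrayRef<Size> dimShape = stt.getDimShape();    // dimension-shape {3, 4}
///   SmallVector<Size> lvlShape = stt.getLvlShape(); // level-shape {4, 3}
///   assert(stt.getDimRank() == 2 && stt.getLvlRank() == 2);
///   assert(!stt.isIdentity() && stt.isPermutation());
/// \endcode
///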
class SparseTensorType {
public:
  // We memoize `lvlRank`, `dimToLvl`, and `lvlToDim` to avoid repeating
  // the conditionals throughout the rest of the class.
  SparseTensorType(RankedTensorType rtp)
      : rtp(rtp), enc(getSparseTensorEncoding(rtp)),
        lvlRank(enc ? enc.getLvlRank() : getDimRank()),
        dimToLvl(enc.isIdentity() ? AffineMap() : enc.getDimToLvl()),
        lvlToDim(enc.isIdentity() ? AffineMap() : enc.getLvlToDim()) {
    assert(rtp && "got null RankedTensorType");
    assert((!isIdentity() || getDimRank() == lvlRank) && "Rank mismatch");
  }

  SparseTensorType(ShapedType stp, SparseTensorEncodingAttr enc)
      : SparseTensorType(
            RankedTensorType::get(stp.getShape(), stp.getElementType(), enc)) {}

  SparseTensorType &operator=(const SparseTensorType &) = delete;
  SparseTensorType(const SparseTensorType &) = default;

  //
  // Factory methods to construct a new `SparseTensorType`
  // with the same dimension-shape and element type.
  //

  SparseTensorType withEncoding(SparseTensorEncodingAttr newEnc) const {
    return SparseTensorType(rtp, newEnc);
  }

  SparseTensorType withDimToLvl(AffineMap dimToLvl) const {
    return withEncoding(enc.withDimToLvl(dimToLvl));
  }

  SparseTensorType withDimToLvl(SparseTensorEncodingAttr dimToLvlEnc) const {
    return withEncoding(enc.withDimToLvl(dimToLvlEnc));
  }

  SparseTensorType withDimToLvl(const SparseTensorType &dimToLvlSTT) const {
    return withDimToLvl(dimToLvlSTT.getEncoding());
  }

  SparseTensorType withoutDimToLvl() const {
    return withEncoding(enc.withoutDimToLvl());
  }

  SparseTensorType withBitWidths(unsigned posWidth, unsigned crdWidth) const {
    return withEncoding(enc.withBitWidths(posWidth, crdWidth));
  }

  SparseTensorType withoutBitWidths() const {
    return withEncoding(enc.withoutBitWidths());
  }

  SparseTensorType withExplicitVal(Attribute explicitVal) const {
    return withEncoding(enc.withExplicitVal(explicitVal));
  }

  SparseTensorType withoutExplicitVal() const {
    return withEncoding(enc.withoutExplicitVal());
  }

  SparseTensorType withImplicitVal(Attribute implicitVal) const {
    return withEncoding(enc.withImplicitVal(implicitVal));
  }

  SparseTensorType withoutImplicitVal() const {
    return withEncoding(enc.withoutImplicitVal());
  }

  SparseTensorType
  withDimSlices(ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
    return withEncoding(enc.withDimSlices(dimSlices));
  }

  SparseTensorType withoutDimSlices() const {
    return withEncoding(enc.withoutDimSlices());
  }
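
  // For illustration, a hedged sketch of chaining the factory methods above
  // (`stt` is a hypothetical SparseTensorType; the chosen widths are
  // arbitrary):
  //
  //   // Same dimension-shape and element type, but with 32-bit overhead
  //   // storage types and the custom dimToLvl mapping dropped.
  //   SparseTensorType simplified =
  //       stt.withBitWidths(/*posWidth=*/32, /*crdWidth=*/32).withoutDimToLvl();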

  /// Allow implicit conversion to `RankedTensorType`, `ShapedType`,
  /// and `Type`. These are implicit to help alleviate the impedance
  /// mismatch for code that has not been converted to use `SparseTensorType`
  /// directly. Once more uses have been converted to `SparseTensorType`,
  /// we may want to make these explicit instead.
  ///
  /// WARNING: This user-defined-conversion method causes overload
  /// ambiguity whenever passing a `SparseTensorType` directly to a
  /// function which is overloaded to accept either `Type` or `TypeRange`.
  /// In particular, this includes `RewriterBase::replaceOpWithNewOp<OpTy>`
  /// and `OpTy::create` whenever `OpTy::build` is overloaded in that way.
  /// This happens because the `TypeRange<T>(T&&)` ctor is implicit
  /// as well, and there's no SFINAE we can add to this method that would
  /// block subsequent application of that ctor. The only way to fix the
  /// overload ambiguity is to avoid *implicit* conversion at the callsite:
  /// e.g., by using `static_cast` to make the conversion explicit, by
  /// assigning the `SparseTensorType` to a temporary variable of the
  /// desired type, etc.
  //
  // NOTE: We implement this as a single templated user-defined-conversion
  // function to avoid ambiguity problems when the desired result is `Type`
  // (since both `RankedTensorType` and `ShapedType` can be implicitly
  // converted to `Type`).
  template <typename T, typename = std::enable_if_t<
                            std::is_convertible_v<RankedTensorType, T>>>
  /*implicit*/ operator T() const {
    return rtp;
  }
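
  // For illustration, a hedged sketch of the workaround described in the
  // WARNING above (`SomeOp`, `rewriter`, `op`, and `operand` are
  // hypothetical):
  //
  //   // Ambiguous: `stt` converts implicitly to both `Type` and `TypeRange`.
  //   //   rewriter.replaceOpWithNewOp<SomeOp>(op, stt, operand);
  //   // Unambiguous: make the conversion explicit at the callsite.
  //   rewriter.replaceOpWithNewOp<SomeOp>(
  //       op, static_cast<RankedTensorType>(stt), operand);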

  /// Explicitly convert to `RankedTensorType`. This method is
  /// a convenience for resolving overload-ambiguity issues with
  /// implicit conversion.
  RankedTensorType getRankedTensorType() const { return rtp; }

  bool operator==(const SparseTensorType &other) const {
    // All other fields are derived from `rtp` and therefore don't need
    // to be checked.
    return rtp == other.rtp;
  }

  bool operator!=(const SparseTensorType &other) const {
    return !(*this == other);
  }

  MLIRContext *getContext() const { return rtp.getContext(); }

  Type getElementType() const { return rtp.getElementType(); }

  SparseTensorEncodingAttr getEncoding() const { return enc; }

  //
  // SparseTensorEncodingAttr delegators
  //

  /// Returns true for tensors which have an encoding, and false for
  /// those which do not. Therefore tensors with an all-dense encoding
  /// return true.
  bool hasEncoding() const { return static_cast<bool>(enc); }

  /// Returns true for tensors where every level is dense.
  /// (This is always true for dense-tensors.)
  bool isAllDense() const { return enc.isAllDense(); }

  /// Returns true for tensors where every level is ordered.
  /// (This is always true for dense-tensors.)
  bool isAllOrdered() const { return enc.isAllOrdered(); }

  /// Translates between level / dimension coordinate space.
  ValueRange translateCrds(OpBuilder &builder, Location loc, ValueRange crds,
                           CrdTransDirectionKind dir) const {
    return enc.translateCrds(builder, loc, crds, dir);
  }

  /// Returns true if the dimToLvl mapping is a permutation.
  /// (This is always true for dense-tensors.)
  bool isPermutation() const { return enc.isPermutation(); }

  /// Returns true if the dimToLvl mapping is the identity.
  /// (This is always true for dense-tensors.)
  bool isIdentity() const { return enc.isIdentity(); }

  //
  // Other methods.
  //

  /// Returns the dimToLvl mapping (or the null-map for the identity).
  /// If you intend to compare the results of this method for equality,
  /// see `hasSameDimToLvl` instead.
  AffineMap getDimToLvl() const { return dimToLvl; }

  /// Returns the lvlToDim mapping (or the null-map for the identity).
  AffineMap getLvlToDim() const { return lvlToDim; }

  /// Returns the dimToLvl mapping, where the identity map is expanded out
  /// into a full `AffineMap`. This method is provided as a convenience,
  /// but for most purposes other methods (`isIdentity`, `getDimToLvl`,
  /// etc) will be more helpful.
  AffineMap getExpandedDimToLvl() const {
    return dimToLvl
               ? dimToLvl
               : AffineMap::getMultiDimIdentityMap(getDimRank(), getContext());
  }

  /// Returns true iff the two types have the same mapping. This method
  /// takes care to handle identity maps properly, so it should be preferred
  /// over using `getDimToLvl` followed by `AffineMap::operator==`.
  bool hasSameDimToLvl(const SparseTensorType &other) const {
    // If the maps are the identity, then we need to check the rank
    // to be sure they're the same size identity. (And since identity
    // means dimRank==lvlRank, we use lvlRank as a minor optimization.)
    return isIdentity() ? (other.isIdentity() && lvlRank == other.lvlRank)
                        : (dimToLvl == other.dimToLvl);
  }

  /// Returns the dimension-rank.
  Dimension getDimRank() const { return rtp.getRank(); }

  /// Returns the level-rank.
  Level getLvlRank() const { return lvlRank; }

  /// Returns the dimension-shape.
  ArrayRef<Size> getDimShape() const { return rtp.getShape(); }

  /// Returns the level-shape.
  SmallVector<Size> getLvlShape() const {
    return getEncoding().translateShape(getDimShape(),
                                        CrdTransDirectionKind::dim2lvl);
  }

  /// Returns the batched level-rank.
  unsigned getBatchLvlRank() const { return getEncoding().getBatchLvlRank(); }

  /// Returns the batched level-shape.
  SmallVector<Size> getBatchLvlShape() const {
    auto lvlShape = getEncoding().translateShape(
        getDimShape(), CrdTransDirectionKind::dim2lvl);
    lvlShape.truncate(getEncoding().getBatchLvlRank());
    return lvlShape;
  }

  /// Returns the type with an identity mapping.
  RankedTensorType getDemappedType() const {
    return RankedTensorType::get(getLvlShape(), getElementType(),
                                 enc.withoutDimToLvl());
  }

  /// Safely looks up the requested dimension-DynSize. If you intend
  /// to check the result with `ShapedType::isDynamic`, then see the
  /// `getStaticDimSize` method instead.
  Size getDynamicDimSize(Dimension d) const {
    assert(d < getDimRank() && "Dimension is out of bounds");
    return getDimShape()[d];
  }
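
  // For illustration, a hedged sketch of the intended usage (`stt` and `d`
  // are hypothetical):
  //
  //   Size sz = stt.getDynamicDimSize(d);
  //   if (ShapedType::isDynamic(sz)) {
  //     // Size is unknown at compile time.
  //   }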

  /// Returns true if no dimension has dynamic size.
  bool hasStaticDimShape() const { return rtp.hasStaticShape(); }

  /// Returns true if any dimension has dynamic size.
  bool hasDynamicDimShape() const { return !hasStaticDimShape(); }

  /// Returns true if the given dimension has dynamic size. If you
  /// intend to call `getDynamicDimSize` based on the result, then see
  /// the `getStaticDimSize` method instead.
  bool isDynamicDim(Dimension d) const {
    // We don't use `rtp.isDynamicDim(d)` because we want the
    // OOB error message to be consistent with `getDynamicDimSize`.
    return ShapedType::isDynamic(getDynamicDimSize(d));
  }

  /// Returns the number of dimensions which have dynamic sizes.
  /// The return type is `int64_t` to maintain consistency with
  /// `ShapedType::Trait<T>::getNumDynamicDims`.
  size_t getNumDynamicDims() const { return rtp.getNumDynamicDims(); }

  ArrayRef<LevelType> getLvlTypes() const { return enc.getLvlTypes(); }
  LevelType getLvlType(Level l) const {
    // This OOB check is for dense-tensors, since this class knows
    // their lvlRank (whereas STEA::getLvlType will/can only check
    // OOB for sparse-tensors).
    assert(l < lvlRank && "Level out of bounds");
    return enc.getLvlType(l);
  }

  // We can't just delegate these, since we want to use this class's
  // `getLvlType` method instead of STEA's.
  bool isDenseLvl(Level l) const { return isDenseLT(getLvlType(l)); }
  bool isCompressedLvl(Level l) const { return isCompressedLT(getLvlType(l)); }
  bool isLooseCompressedLvl(Level l) const {
    return isLooseCompressedLT(getLvlType(l));
  }
  bool isSingletonLvl(Level l) const { return isSingletonLT(getLvlType(l)); }
  bool isNOutOfMLvl(Level l) const { return isNOutOfMLT(getLvlType(l)); }
  bool isOrderedLvl(Level l) const { return isOrderedLT(getLvlType(l)); }
  bool isUniqueLvl(Level l) const { return isUniqueLT(getLvlType(l)); }
  bool isWithPos(Level l) const { return isWithPosLT(getLvlType(l)); }
  bool isWithCrd(Level l) const { return isWithCrdLT(getLvlType(l)); }

  /// Returns the coordinate-overhead bitwidth, defaulting to zero.
  unsigned getCrdWidth() const { return enc ? enc.getCrdWidth() : 0; }

  /// Returns the position-overhead bitwidth, defaulting to zero.
  unsigned getPosWidth() const { return enc ? enc.getPosWidth() : 0; }

  /// Returns the explicit value, defaulting to null Attribute for unset.
  Attribute getExplicitVal() const {
    return enc ? enc.getExplicitVal() : nullptr;
  }

  /// Returns the implicit value, defaulting to null Attribute for 0.
  Attribute getImplicitVal() const {
    return enc ? enc.getImplicitVal() : nullptr;
  }

  /// Returns the coordinate-overhead MLIR type, defaulting to `IndexType`.
  Type getCrdType() const { return enc.getCrdElemType(); }

  /// Returns the position-overhead MLIR type, defaulting to `IndexType`.
  Type getPosType() const { return enc.getPosElemType(); }

  /// Returns true iff this sparse tensor type has a trailing
  /// COO region starting at the given level. By default, it
  /// tests for a unique COO type at top level.
  bool isCOOType(Level startLvl = 0, bool isUnique = true) const;

  /// Returns the starting level of this sparse tensor type for a
  /// trailing COO region that spans **at least** two levels. If
  /// no such COO region is found, then returns the level-rank.
  ///
  /// DEPRECATED: use `getCOOSegments` instead.
  Level getAoSCOOStart() const { return getEncoding().getAoSCOOStart(); };

  /// Returns [un]ordered COO type for this sparse tensor type.
  RankedTensorType getCOOType(bool ordered) const;

  /// Returns a list of COO segments in the sparse tensor type.
  SmallVector<COOSegment> getCOOSegments() const {
    return getEncoding().getCOOSegments();
  }

private:
  // These two must be const, to ensure coherence of the memoized fields.
  const RankedTensorType rtp;
  const SparseTensorEncodingAttr enc;
  // Memoized to avoid frequent redundant conditionals.
  const Level lvlRank;
  const AffineMap dimToLvl;
  const AffineMap lvlToDim;
};

/// Convenience methods to obtain a SparseTensorType from a Value.
inline SparseTensorType getSparseTensorType(Value val) {
  return SparseTensorType(cast<RankedTensorType>(val.getType()));
}
inline std::optional<SparseTensorType> tryGetSparseTensorType(Value val) {
  if (auto rtp = dyn_cast<RankedTensorType>(val.getType()))
    return SparseTensorType(rtp);
  return std::nullopt;
}
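
// For illustration, a hedged sketch of using the convenience functions above
// (`val` is a hypothetical Value whose type may or may not be a
// RankedTensorType):
//
//   if (std::optional<SparseTensorType> stt = tryGetSparseTensorType(val)) {
//     if (stt->hasEncoding() && !stt->isAllDense()) {
//       // Handle the sparse case.
//     }
//   }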

} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_