MLIR 22.0.0git
SparseTensorDescriptor.cpp
Go to the documentation of this file.
1//===- SparseTensorDescriptor.cpp -----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "CodegenUtils.h"
11
16
17using namespace mlir;
18using namespace sparse_tensor;
19
20//===----------------------------------------------------------------------===//
21// Private helper methods.
22//===----------------------------------------------------------------------===//
23
24/// Constructs a nullable `LevelAttr` from the `std::optional<Level>`.
25static IntegerAttr optionalLevelAttr(MLIRContext *ctx,
26 std::optional<Level> lvl) {
27 return lvl ? IntegerAttr::get(IndexType::get(ctx), lvl.value())
28 : IntegerAttr();
29}
30
31// This is only ever called from `SparseTensorTypeToBufferConverter`,
32// which is why the first argument is `RankedTensorType` rather than
33// `SparseTensorType`.
34static std::optional<LogicalResult>
35convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
36 const SparseTensorType stt(rtp);
37 if (!stt.hasEncoding())
38 return std::nullopt;
39
40 unsigned numFields = fields.size();
41 (void)numFields;
43 stt,
44 [&](Type fieldType, FieldIndex fieldIdx,
45 SparseTensorFieldKind /*fieldKind*/, Level /*lvl*/,
46 LevelType /*lt*/) -> bool {
47 assert(numFields + fieldIdx == fields.size());
48 fields.push_back(fieldType);
49 return true;
50 });
51 return success();
52}
53
54//===----------------------------------------------------------------------===//
55// The sparse tensor type converter (defined in Passes.h).
56//===----------------------------------------------------------------------===//
57
58static Value materializeTuple(OpBuilder &builder, RankedTensorType tp,
59 ValueRange inputs, Location loc) {
61 // Not a sparse tensor.
62 return Value();
63 // Sparsifier knows how to cancel out these casts.
64 return genTuple(builder, loc, tp, inputs);
65}
66
// NOTE(review): the line that opens this definition (listing line 67) is
// missing from this extract. Given the comment on `convertSparseTensorType`,
// this is presumably the `SparseTensorTypeToBufferConverter` constructor --
// confirm against the upstream file.
//
// The visible body registers two type conversions and one source
// materialization.
// Identity conversion: the callback returns its argument unchanged.
68 addConversion([](Type type) { return type; });
// Conversion implemented by `convertSparseTensorType`.
69 addConversion(convertSparseTensorType);
70
71 // Required by scf.for 1:N type conversion.
72 addSourceMaterialization(materializeTuple);
73}
74
75//===----------------------------------------------------------------------===//
76// StorageTensorSpecifier methods.
77//===----------------------------------------------------------------------===//
78
// NOTE(review): the first line of this signature (listing line 79) is missing
// from this extract; the visible remainder is consistent with
// `getInitValue(OpBuilder &builder, Location loc, SparseTensorType stt)` --
// confirm against the upstream file.
//
// Creates a `StorageSpecifierInitOp` at `loc` whose type is the
// `StorageSpecifierType` obtained from the encoding of `stt`, and returns
// the result of that creation.
80 SparseTensorType stt) {
81 return StorageSpecifierInitOp::create(
82 builder, loc, StorageSpecifierType::get(stt.getEncoding()));
83}
84
// NOTE(review): the first line of this signature (listing line 85) is missing
// from this extract; the visible remainder is consistent with
// `getSpecifierField(OpBuilder &builder, Location loc,
// StorageSpecifierKind kind, std::optional<Level> lvl)` -- confirm against
// the upstream file.
//
// Creates a `GetStorageSpecifierOp` that reads the field selected by `kind`
// from `specifier`. The level is passed as an attribute that is null when
// `lvl` is `std::nullopt` (see `optionalLevelAttr`).
86 StorageSpecifierKind kind,
87 std::optional<Level> lvl) {
88 return GetStorageSpecifierOp::create(
89 builder, loc, specifier, kind,
90 optionalLevelAttr(specifier.getContext(), lvl));
91}
92
// NOTE(review): the first line of this signature (listing line 93) is missing
// from this extract; the visible remainder is consistent with
// `setSpecifierField(OpBuilder &builder, Location loc, Value v,
// StorageSpecifierKind kind, std::optional<Level> lvl)` -- confirm against
// the upstream file.
//
// Creates a `SetStorageSpecifierOp` that stores `v` into the field selected
// by `kind` (and the optional level), then rebinds `specifier` to the result
// of that operation. `v` must be of index type; this is only checked by an
// assertion.
94 Value v,
95 StorageSpecifierKind kind,
96 std::optional<Level> lvl) {
97 // TODO: make `v` have type `TypedValue<IndexType>` instead.
98 assert(v.getType().isIndex());
99 specifier = SetStorageSpecifierOp::create(
100 builder, loc, specifier, kind,
101 optionalLevelAttr(specifier.getContext(), lvl), v);
102}
103
104//===----------------------------------------------------------------------===//
105// SparseTensorDescriptor methods.
106//===----------------------------------------------------------------------===//
107
// NOTE(review): the first line of this signature (listing line 108) is
// missing from this extract; the visible remainder is consistent with
// `getCrdMemRefOrView(OpBuilder &builder, Location loc, Level lvl) const` --
// confirm against the upstream file.
//
// For levels at or after `rType.getAoSCOOStart()`, returns a strided
// `memref::SubViewOp` into the coordinate memref stored at the COO start
// level.
109 OpBuilder &builder, Location loc, Level lvl) const {
110 const Level cooStart = rType.getAoSCOOStart();
// NOTE(review): the statement controlled by this `if` (listing line 112) is
// missing from this extract. As shown, the `if` would wrongly govern the
// `stride` declaration below. Presumably the missing line returns the
// level's own coordinate memref -- confirm against the upstream file.
111 if (lvl < cooStart)
113
// Stride is the number of levels from `cooStart` up to the level rank.
114 Value stride = constantIndex(builder, loc, rType.getLvlRank() - cooStart);
// Size is the coordinate memory size at `cooStart` divided (unsigned) by
// the stride.
115 Value size = getCrdMemSize(builder, loc, cooStart);
116 size = arith::DivUIOp::create(builder, loc, size, stride);
// The view starts at offset `lvl - cooStart` and steps by `stride`.
117 return memref::SubViewOp::create(
118 builder, loc, getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart),
119 /*offset=*/ValueRange{constantIndex(builder, loc, lvl - cooStart)},
120 /*size=*/ValueRange{size},
121 /*step=*/ValueRange{stride});
122}
return success()
static Value materializeTuple(OpBuilder &builder, RankedTensorType tp, ValueRange inputs, Location loc)
static std::optional< LogicalResult > convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl< Type > &fields)
static IntegerAttr optionalLevelAttr(MLIRContext *ctx, std::optional< Level > lvl)
Constructs a nullable LevelAttr from the std::optional<Level>.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Value getCrdMemSize(OpBuilder &builder, Location loc, Level lvl) const
Value getMemRefField(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Value getCrdMemRefOrView(OpBuilder &builder, Location loc, Level lvl) const
void setSpecifierField(OpBuilder &builder, Location loc, Value v, StorageSpecifierKind kind, std::optional< Level > lvl)
Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl)
static Value getInitValue(OpBuilder &builder, Location loc, SparseTensorType stt)
A wrapper around RankedTensorType, which has three goals:
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
SparseTensorEncodingAttr getEncoding() const
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
Value genTuple(OpBuilder &builder, Location loc, Type tp, ValueRange values)
Packs the given values as a "tuple" value.
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
SparseTensorFieldKind
The sparse tensor storag...
Include the generated interface declarations.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition Enums.h:238