MLIR  22.0.0git
SparseTensorDescriptor.cpp
Go to the documentation of this file.
1 //===- SparseTensorDescriptor.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "CodegenUtils.h"
11 
16 
17 using namespace mlir;
18 using namespace sparse_tensor;
19 
20 //===----------------------------------------------------------------------===//
21 // Private helper methods.
22 //===----------------------------------------------------------------------===//
23 
24 /// Constructs a nullable `LevelAttr` from the `std::optional<Level>`.
25 static IntegerAttr optionalLevelAttr(MLIRContext *ctx,
26  std::optional<Level> lvl) {
27  return lvl ? IntegerAttr::get(IndexType::get(ctx), lvl.value())
28  : IntegerAttr();
29 }
30 
31 // This is only ever called from `SparseTensorTypeToBufferConverter`,
32 // which is why the first argument is `RankedTensorType` rather than
33 // `SparseTensorType`.
34 static std::optional<LogicalResult>
35 convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
36  const SparseTensorType stt(rtp);
37  if (!stt.hasEncoding())
38  return std::nullopt;
39 
40  unsigned numFields = fields.size();
41  (void)numFields;
43  stt,
44  [&](Type fieldType, FieldIndex fieldIdx,
45  SparseTensorFieldKind /*fieldKind*/, Level /*lvl*/,
46  LevelType /*lt*/) -> bool {
47  assert(numFields + fieldIdx == fields.size());
48  fields.push_back(fieldType);
49  return true;
50  });
51  return success();
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // The sparse tensor type converter (defined in Passes.h).
56 //===----------------------------------------------------------------------===//
57 
58 static Value materializeTuple(OpBuilder &builder, RankedTensorType tp,
59  ValueRange inputs, Location loc) {
60  if (!getSparseTensorEncoding(tp))
61  // Not a sparse tensor.
62  return Value();
63  // Sparsifier knows how to cancel out these casts.
64  return genTuple(builder, loc, tp, inputs);
65 }
66 
// NOTE(review): the constructor's header line is missing from this listing.
// Presumably it is the `SparseTensorTypeToBufferConverter` constructor that
// the comment on `convertSparseTensorType` refers to; confirm against the
// original file.
// Registers an identity conversion (every type maps to itself) and the
// sparse-tensor conversion that expands a sparse tensor into its storage
// field types.
// NOTE(review): MLIR's TypeConverter is believed to try conversions in
// reverse registration order, so the sparse conversion would be consulted
// before the identity fallback. Verify against the TypeConverter docs.
68  addConversion([](Type type) { return type; });
69  addConversion(convertSparseTensorType);
70 
71  // Required by scf.for 1:N type conversion.
72  addSourceMaterialization(materializeTuple);
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // StorageTensorSpecifier methods.
77 //===----------------------------------------------------------------------===//
78 
80  SparseTensorType stt) {
// NOTE(review): the first line of this signature is missing from the listing.
// The index lists it as `getInitValue(OpBuilder &builder, Location loc,
// SparseTensorType stt)`.
// Creates a `StorageSpecifierInitOp` whose result type is the
// `StorageSpecifierType` built from `stt`'s encoding, and returns its result.
81  return StorageSpecifierInitOp::create(
82  builder, loc, StorageSpecifierType::get(stt.getEncoding()));
83 }
84 
86  StorageSpecifierKind kind,
87  std::optional<Level> lvl) {
// NOTE(review): the first line of this signature is missing from the listing.
// The index lists it as `getSpecifierField(OpBuilder &builder, Location loc,
// StorageSpecifierKind kind, std::optional<Level> lvl)`.
// Reads field `kind` of `specifier` (a value declared outside this listing)
// by creating a `GetStorageSpecifierOp`. An empty `lvl` is passed as a null
// level attribute.
88  return GetStorageSpecifierOp::create(
89  builder, loc, specifier, kind,
90  optionalLevelAttr(specifier.getContext(), lvl));
91 }
92 
94  Value v,
95  StorageSpecifierKind kind,
96  std::optional<Level> lvl) {
// NOTE(review): the first line of this signature is missing from the listing.
// The index lists it as `setSpecifierField(OpBuilder &builder, Location loc,
// Value v, StorageSpecifierKind kind, std::optional<Level> lvl)`.
// Writes `v` into field `kind` (at level `lvl`, if given) of `specifier`.
97  // TODO: make `v` have type `TypedValue<IndexType>` instead.
// `v` must be index-typed; this is only checked in assert-enabled builds.
98  assert(v.getType().isIndex());
// The set op yields a new specifier value. Rebinding `specifier` to it means
// later uses of `specifier` see the updated field.
99  specifier = SetStorageSpecifierOp::create(
100  builder, loc, specifier, kind,
101  optionalLevelAttr(specifier.getContext(), lvl), v);
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // SparseTensorDescriptor methods.
106 //===----------------------------------------------------------------------===//
107 
109  OpBuilder &builder, Location loc, Level lvl) const {
// NOTE(review): the first line of this signature is missing from the listing.
// The index lists it as `getCrdMemRefOrView(OpBuilder &builder, Location loc,
// Level lvl) const`.
// Returns the coordinate storage for level `lvl`.
110  const Level cooStart = rType.getAoSCOOStart();
// Levels before the AoS COO start each have their own coordinate memref,
// which is returned directly.
111  if (lvl < cooStart)
112  return getMemRefField(SparseTensorFieldKind::CrdMemRef, lvl);
113 
// Levels from `cooStart` onward share the single coordinate memref stored at
// field level `cooStart`. The coordinates of the (lvlRank - cooStart) COO
// levels are interleaved there, so level `lvl` is the strided view that
// starts at offset (lvl - cooStart) and steps by that count.
114  Value stride = constantIndex(builder, loc, rType.getLvlRank() - cooStart);
// Per-level element count = shared buffer size / stride. `getCrdMemSize` is
// presumably the used size of the shared COO buffer; confirm.
115  Value size = getCrdMemSize(builder, loc, cooStart);
116  size = arith::DivUIOp::create(builder, loc, size, stride);
117  return memref::SubViewOp::create(
118  builder, loc, getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart),
119  /*offset=*/ValueRange{constantIndex(builder, loc, lvl - cooStart)},
120  /*size=*/ValueRange{size},
121  /*step=*/ValueRange{stride});
122 }
union mlir::linalg::@1242::ArityGroupAndKind::Kind kind
static std::optional< LogicalResult > convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl< Type > &fields)
static Value materializeTuple(OpBuilder &builder, RankedTensorType tp, ValueRange inputs, Location loc)
static IntegerAttr optionalLevelAttr(MLIRContext *ctx, std::optional< Level > lvl)
Constructs a nullable LevelAttr from the std::optional<Level>.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:205
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Value getCrdMemRefOrView(OpBuilder &builder, Location loc, Level lvl) const
void setSpecifierField(OpBuilder &builder, Location loc, Value v, StorageSpecifierKind kind, std::optional< Level > lvl)
Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl)
static Value getInitValue(OpBuilder &builder, Location loc, SparseTensorType stt)
A wrapper around RankedTensorType, which has three goals:
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
SparseTensorEncodingAttr getEncoding() const
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:42
Value genTuple(OpBuilder &builder, Location loc, Type tp, ValueRange values)
Packs the given values as a "tuple" value.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
SparseTensorFieldKind
The sparse tensor storage ... (description truncated in this listing)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238