MLIR  20.0.0git
SparseTensorDescriptor.cpp
Go to the documentation of this file.
1 //===- SparseTensorDescriptor.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "CodegenUtils.h"
11 
17 
18 using namespace mlir;
19 using namespace sparse_tensor;
20 
21 //===----------------------------------------------------------------------===//
22 // Private helper methods.
23 //===----------------------------------------------------------------------===//
24 
25 /// Constructs a nullable `LevelAttr` from the `std::optional<Level>`.
26 static IntegerAttr optionalLevelAttr(MLIRContext *ctx,
27  std::optional<Level> lvl) {
28  return lvl ? IntegerAttr::get(IndexType::get(ctx), lvl.value())
29  : IntegerAttr();
30 }
31 
32 // This is only ever called from `SparseTensorTypeToBufferConverter`,
33 // which is why the first argument is `RankedTensorType` rather than
34 // `SparseTensorType`.
// Appends the type of every storage field of `rtp` to `fields`. The assert
// in the callback requires each field index to equal the number of types
// already collected, so `fields` is expected to be empty on entry and ends
// up in field-index order.
// Returns std::nullopt when `rtp` has no sparse encoding -- presumably so the
// type converter falls back to another registered conversion (confirm) --
// and success() once all fields have been visited.
35 static std::optional<LogicalResult>
36 convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
37  const SparseTensorType stt(rtp);
38  if (!stt.hasEncoding())
39  return std::nullopt;
40 
  // NOTE(review): source line 41 is missing from this listing. Judging by the
  // cross-reference to foreachFieldAndTypeInSparseTensor(SparseTensorType,
  // function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
  // LevelType)>), it is presumably the call that receives `stt` and the
  // callback below -- confirm against the original file.
42  stt,
43  [&fields](Type fieldType, FieldIndex fieldIdx,
44  SparseTensorFieldKind /*fieldKind*/, Level /*lvl*/,
45  LevelType /*lt*/) -> bool {
  // Only the field type is recorded; kind, level and level type are unused.
46  assert(fieldIdx == fields.size());
47  fields.push_back(fieldType);
  // Presumably `true` tells the walker to keep visiting fields (confirm).
48  return true;
49  });
50  return success();
51 }
52 
53 //===----------------------------------------------------------------------===//
54 // The sparse tensor type converter (defined in Passes.h).
55 //===----------------------------------------------------------------------===//
56 
57 static Value materializeTuple(OpBuilder &builder, RankedTensorType tp,
58  ValueRange inputs, Location loc) {
59  if (!getSparseTensorEncoding(tp))
60  // Not a sparse tensor.
61  return Value();
62  // Sparsifier knows how to cancel out these casts.
63  return genTuple(builder, loc, tp, inputs);
64 }
65 
67  addConversion([](Type type) { return type; });
  // NOTE(review): the constructor header (source line 66) is missing from
  // this listing. Per the section banner this is the sparse tensor type
  // converter (defined in Passes.h), presumably the
  // SparseTensorTypeToBufferConverter constructor -- confirm.
  // Above: identity conversion for every type. Below: the sparse-tensor
  // conversion, which expands an encoded tensor type into its storage field
  // types and yields std::nullopt for unencoded tensors.
  // NOTE(review): this ordering relies on the TypeConverter trying the most
  // recently added conversion first, so sparse tensors reach
  // convertSparseTensorType before the identity fallback -- confirm.
68  addConversion(convertSparseTensorType);
69 
70  // Required by scf.for 1:N type conversion.
  // materializeTuple packs the converted field values back into one value of
  // the original sparse tensor type (and returns a null Value otherwise).
71  addSourceMaterialization(materializeTuple);
72 
73  // Required as a workaround until we have full 1:N support.
74  addArgumentMaterialization(materializeTuple);
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // StorageTensorSpecifier methods.
79 //===----------------------------------------------------------------------===//
80 
82  SparseTensorType stt) {
  // Creates a StorageSpecifierInitOp that produces the initial storage
  // specifier value for the sparse tensor type `stt`.
  // NOTE(review): source lines 81 and 84 are missing from this listing.
  // Line 81 is the signature head; the cross-references list it as
  // `static Value getInitValue(OpBuilder &builder, Location loc,
  // SparseTensorType stt)`. Line 84 holds the op's operands. Confirm both
  // against the original file.
83  return builder.create<StorageSpecifierInitOp>(
85 }
86 
88  StorageSpecifierKind kind,
89  std::optional<Level> lvl) {
  // Emits a GetStorageSpecifierOp that reads the `kind` field of `specifier`,
  // qualified by level `lvl` when one is given. An empty `lvl` becomes a
  // null level attribute via optionalLevelAttr.
  // NOTE(review): the signature head (source line 87) is missing from this
  // listing. `specifier` is not declared in this block; it is presumably a
  // member holding the current specifier value -- confirm.
90  return builder.create<GetStorageSpecifierOp>(
91  loc, specifier, kind, optionalLevelAttr(specifier.getContext(), lvl));
92 }
93 
95  Value v,
96  StorageSpecifierKind kind,
97  std::optional<Level> lvl) {
  // Emits a SetStorageSpecifierOp that writes `v` into the `kind` field of
  // `specifier` (at level `lvl` when given). It then rebinds `specifier` to
  // the op's result, so later reads through `specifier` see the updated
  // value.
  // NOTE(review): the signature head (source line 94) is missing from this
  // listing -- confirm the leading parameters against the original file.
98  // TODO: make `v` have type `TypedValue<IndexType>` instead.
  // `v` must be index-typed; this is only checked in assert-enabled builds.
99  assert(v.getType().isIndex());
100  specifier = builder.create<SetStorageSpecifierOp>(
101  loc, specifier, kind, optionalLevelAttr(specifier.getContext(), lvl), v);
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // SparseTensorDescriptor methods.
106 //===----------------------------------------------------------------------===//
107 
109  OpBuilder &builder, Location loc, Level lvl) const {
  // Returns the coordinates of level `lvl` as a memref value.
  // - Levels before the AoS COO start each have their own coordinate memref,
  //   which is returned directly.
  // - Levels from cooStart onward share the coordinate memref stored at
  //   cooStart. The code treats that buffer as interleaved: one coordinate
  //   per COO level per stored entry. For such a level, it builds a strided
  //   memref.subview starting at offset (lvl - cooStart) with step
  //   (lvlRank - cooStart).
  // NOTE(review): the signature head (source line 108) is missing from this
  // listing; per the cross-references it is presumably
  // SparseTensorDescriptor::getCrdMemRefOrView -- confirm.
110  const Level cooStart = rType.getAoSCOOStart();
111  if (lvl < cooStart)
112  return getMemRefField(SparseTensorFieldKind::CrdMemRef, lvl);
113 
  // Number of COO levels, i.e. the distance between two consecutive
  // coordinates of the same level in the shared buffer.
114  Value stride = constantIndex(builder, loc, rType.getLvlRank() - cooStart);
  // Number of coordinates per level: the shared buffer's size divided by the
  // stride (unsigned division). Assumes getCrdMemSize yields the used element
  // count of that buffer -- confirm.
115  Value size = getCrdMemSize(builder, loc, cooStart);
116  size = builder.create<arith::DivUIOp>(loc, size, stride);
117  return builder.create<memref::SubViewOp>(
118  loc, getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart),
119  /*offset=*/ValueRange{constantIndex(builder, loc, lvl - cooStart)},
120  /*size=*/ValueRange{size},
121  /*step=*/ValueRange{stride});
122 }
static std::optional< LogicalResult > convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl< Type > &fields)
static Value materializeTuple(OpBuilder &builder, RankedTensorType tp, ValueRange inputs, Location loc)
static IntegerAttr optionalLevelAttr(MLIRContext *ctx, std::optional< Level > lvl)
Constructs a nullable LevelAttr from the std::optional<Level>.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:216
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:64
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Value getCrdMemRefOrView(OpBuilder &builder, Location loc, Level lvl) const
void setSpecifierField(OpBuilder &builder, Location loc, Value v, StorageSpecifierKind kind, std::optional< Level > lvl)
Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl)
static Value getInitValue(OpBuilder &builder, Location loc, SparseTensorType stt)
A wrapper around RankedTensorType, which has three goals:
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
SparseTensorEncodingAttr getEncoding() const
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:42
Value genTuple(OpBuilder &builder, Location loc, Type tp, ValueRange values)
Packs the given values as a "tuple" value.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
SparseTensorFieldKind
The sparse tensor storag... (tooltip text truncated in this listing)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238