MLIR  21.0.0git
SparseTensorDescriptor.cpp
Go to the documentation of this file.
1 //===- SparseTensorDescriptor.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "CodegenUtils.h"
11 
17 
18 using namespace mlir;
19 using namespace sparse_tensor;
20 
21 //===----------------------------------------------------------------------===//
22 // Private helper methods.
23 //===----------------------------------------------------------------------===//
24 
25 /// Constructs a nullable `LevelAttr` from the `std::optional<Level>`.
26 static IntegerAttr optionalLevelAttr(MLIRContext *ctx,
27  std::optional<Level> lvl) {
28  return lvl ? IntegerAttr::get(IndexType::get(ctx), lvl.value())
29  : IntegerAttr();
30 }
31 
32 // This is only ever called from `SparseTensorTypeToBufferConverter`,
33 // which is why the first argument is `RankedTensorType` rather than
34 // `SparseTensorType`.
35 static std::optional<LogicalResult>
36 convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
37  const SparseTensorType stt(rtp);
38  if (!stt.hasEncoding())
39  return std::nullopt;
40 
41  unsigned numFields = fields.size();
42  (void)numFields;
44  stt,
45  [&](Type fieldType, FieldIndex fieldIdx,
46  SparseTensorFieldKind /*fieldKind*/, Level /*lvl*/,
47  LevelType /*lt*/) -> bool {
48  assert(numFields + fieldIdx == fields.size());
49  fields.push_back(fieldType);
50  return true;
51  });
52  return success();
53 }
54 
55 //===----------------------------------------------------------------------===//
56 // The sparse tensor type converter (defined in Passes.h).
57 //===----------------------------------------------------------------------===//
58 
59 static Value materializeTuple(OpBuilder &builder, RankedTensorType tp,
60  ValueRange inputs, Location loc) {
61  if (!getSparseTensorEncoding(tp))
62  // Not a sparse tensor.
63  return Value();
64  // Sparsifier knows how to cancel out these casts.
65  return genTuple(builder, loc, tp, inputs);
66 }
67 
// NOTE(review): the enclosing function header (original line 68) is absent
// from this listing. Given the section banner and the addConversion /
// addSourceMaterialization calls, this is presumably the type converter's
// constructor; confirm against the full file.
// Fallback conversion: returns every type unchanged.
69  addConversion([](Type type) { return type; });
// convertSparseTensorType returns std::nullopt for tensors without a sparse
// encoding, so such types are left to the identity conversion above. This
// assumes TypeConverter tries the most recently added conversion first;
// TODO confirm.
70  addConversion(convertSparseTensorType);
71 
72  // Required by scf.for 1:N type conversion.
// materializeTuple packs converted buffers back into the sparse tensor type
// via genTuple and yields a null Value for non-sparse tensors.
73  addSourceMaterialization(materializeTuple);
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // StorageTensorSpecifier methods.
78 //===----------------------------------------------------------------------===//
79 
81  SparseTensorType stt) {
// Creates a StorageSpecifierInitOp for `stt` and returns its result as the
// initial storage-specifier value.
// NOTE(review): the signature's first line and the op's operand line
// (original lines 80 and 83) are absent from this listing; restore them from
// the full file before editing.
82  return builder.create<StorageSpecifierInitOp>(
84 }
85 
87  StorageSpecifierKind kind,
88  std::optional<Level> lvl) {
// Emits a GetStorageSpecifierOp that queries field `kind` of the current
// `specifier` value and returns the op's result. When `lvl` is std::nullopt,
// optionalLevelAttr yields a null attribute, so no level is attached.
// NOTE(review): the first line of this signature is absent from this listing.
89  return builder.create<GetStorageSpecifierOp>(
90  loc, specifier, kind, optionalLevelAttr(specifier.getContext(), lvl));
91 }
92 
94  Value v,
95  StorageSpecifierKind kind,
96  std::optional<Level> lvl) {
// Emits a SetStorageSpecifierOp that stores `v` into field `kind` (at level
// `lvl` when one is given) and replaces `specifier` with the op's result.
// Later get/set calls on this object therefore see the updated specifier.
// NOTE(review): the first line of this signature is absent from this listing.
97  // TODO: make `v` have type `TypedValue<IndexType>` instead.
// Only index-typed values are accepted; this is checked in assert-enabled
// builds only.
98  assert(v.getType().isIndex());
99  specifier = builder.create<SetStorageSpecifierOp>(
100  loc, specifier, kind, optionalLevelAttr(specifier.getContext(), lvl), v);
101 }
102 
103 //===----------------------------------------------------------------------===//
104 // SparseTensorDescriptor methods.
105 //===----------------------------------------------------------------------===//
106 
108  OpBuilder &builder, Location loc, Level lvl) const {
// Returns the coordinate buffer for level `lvl`.
// - Levels before `cooStart` (the start of the AoS COO region) each own a
//   CrdMemRef field, which is returned directly.
// - Levels at or after `cooStart` share the single CrdMemRef stored at
//   `cooStart`, whose coordinates are interleaved with stride
//   (lvlRank - cooStart). For those levels a strided memref.subview is
//   returned, starting at offset (lvl - cooStart), with that stride and
//   size = getCrdMemSize(cooStart) / stride.
// NOTE(review): the first line of this signature is absent from this listing.
109  const Level cooStart = rType.getAoSCOOStart();
110  if (lvl < cooStart)
111  return getMemRefField(SparseTensorFieldKind::CrdMemRef, lvl);
112 
// Number of coordinate components stored per COO entry.
113  Value stride = constantIndex(builder, loc, rType.getLvlRank() - cooStart);
// Total stored coordinates divided by the stride, presumably the number of
// COO entries; confirm getCrdMemSize semantics.
114  Value size = getCrdMemSize(builder, loc, cooStart);
115  size = builder.create<arith::DivUIOp>(loc, size, stride);
116  return builder.create<memref::SubViewOp>(
117  loc, getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart),
118  /*offset=*/ValueRange{constantIndex(builder, loc, lvl - cooStart)},
119  /*size=*/ValueRange{size},
120  /*step=*/ValueRange{stride});
121 }
union mlir::linalg::@1195::ArityGroupAndKind::Kind kind
static std::optional< LogicalResult > convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl< Type > &fields)
static Value materializeTuple(OpBuilder &builder, RankedTensorType tp, ValueRange inputs, Location loc)
static IntegerAttr optionalLevelAttr(MLIRContext *ctx, std::optional< Level > lvl)
Constructs a nullable LevelAttr from the std::optional<Level>.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:204
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Value getCrdMemRefOrView(OpBuilder &builder, Location loc, Level lvl) const
void setSpecifierField(OpBuilder &builder, Location loc, Value v, StorageSpecifierKind kind, std::optional< Level > lvl)
Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl)
static Value getInitValue(OpBuilder &builder, Location loc, SparseTensorType stt)
A wrapper around RankedTensorType, which has three goals:
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
SparseTensorEncodingAttr getEncoding() const
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:42
Value genTuple(OpBuilder &builder, Location loc, Type tp, ValueRange values)
Packs the given values as a "tuple" value.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
SparseTensorFieldKind
===---------------------------------------------------------------------===// The sparse tensor storag...
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238