MLIR  19.0.0git
SparseTensorDescriptor.cpp
Go to the documentation of this file.
1 //===- SparseTensorDescriptor.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "CodegenUtils.h"
11 
17 
18 using namespace mlir;
19 using namespace sparse_tensor;
20 
21 //===----------------------------------------------------------------------===//
22 // Private helper methods.
23 //===----------------------------------------------------------------------===//
24 
25 /// Constructs a nullable `LevelAttr` from the `std::optional<Level>`.
26 static IntegerAttr optionalLevelAttr(MLIRContext *ctx,
27  std::optional<Level> lvl) {
28  return lvl ? IntegerAttr::get(IndexType::get(ctx), lvl.value())
29  : IntegerAttr();
30 }
31 
32 // This is only ever called from `SparseTensorTypeToBufferConverter`,
33 // which is why the first argument is `RankedTensorType` rather than
34 // `SparseTensorType`.
35 static std::optional<LogicalResult>
36 convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
37  const SparseTensorType stt(rtp);
38  if (!stt.hasEncoding())
39  return std::nullopt;
40 
42  stt,
43  [&fields](Type fieldType, FieldIndex fieldIdx,
44  SparseTensorFieldKind /*fieldKind*/, Level /*lvl*/,
45  LevelType /*lt*/) -> bool {
46  assert(fieldIdx == fields.size());
47  fields.push_back(fieldType);
48  return true;
49  });
50  return success();
51 }
52 
53 //===----------------------------------------------------------------------===//
54 // The sparse tensor type converter (defined in Passes.h).
55 //===----------------------------------------------------------------------===//
56 
58  addConversion([](Type type) { return type; });
59  addConversion(convertSparseTensorType);
60 
61  // Required by scf.for 1:N type conversion.
62  addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp,
63  ValueRange inputs,
64  Location loc) -> std::optional<Value> {
65  if (!getSparseTensorEncoding(tp))
66  // Not a sparse tensor.
67  return std::nullopt;
68  // Sparsifier knows how to cancel out these casts.
69  return genTuple(builder, loc, tp, inputs);
70  });
71 }
72 
73 //===----------------------------------------------------------------------===//
74 // StorageTensorSpecifier methods.
75 //===----------------------------------------------------------------------===//
76 
78  SparseTensorType stt) {
79  return builder.create<StorageSpecifierInitOp>(
81 }
82 
84  StorageSpecifierKind kind,
85  std::optional<Level> lvl) {
86  return builder.create<GetStorageSpecifierOp>(
87  loc, specifier, kind, optionalLevelAttr(specifier.getContext(), lvl));
88 }
89 
91  Value v,
92  StorageSpecifierKind kind,
93  std::optional<Level> lvl) {
94  // TODO: make `v` have type `TypedValue<IndexType>` instead.
95  assert(v.getType().isIndex());
96  specifier = builder.create<SetStorageSpecifierOp>(
97  loc, specifier, kind, optionalLevelAttr(specifier.getContext(), lvl), v);
98 }
99 
100 //===----------------------------------------------------------------------===//
101 // SparseTensorDescriptor methods.
102 //===----------------------------------------------------------------------===//
103 
105  OpBuilder &builder, Location loc, Level lvl) const {
106  const Level cooStart = rType.getAoSCOOStart();
107  if (lvl < cooStart)
108  return getMemRefField(SparseTensorFieldKind::CrdMemRef, lvl);
109 
110  Value stride = constantIndex(builder, loc, rType.getLvlRank() - cooStart);
111  Value size = getCrdMemSize(builder, loc, cooStart);
112  size = builder.create<arith::DivUIOp>(loc, size, stride);
113  return builder.create<memref::SubViewOp>(
114  loc, getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart),
115  /*offset=*/ValueRange{constantIndex(builder, loc, lvl - cooStart)},
116  /*size=*/ValueRange{size},
117  /*step=*/ValueRange{stride});
118 }
static std::optional< LogicalResult > convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl< Type > &fields)
static IntegerAttr optionalLevelAttr(MLIRContext *ctx, std::optional< Level > lvl)
Constructs a nullable LevelAttr from the std::optional<Level>.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Value getCrdMemRefOrView(OpBuilder &builder, Location loc, Level lvl) const
void setSpecifierField(OpBuilder &builder, Location loc, Value v, StorageSpecifierKind kind, std::optional< Level > lvl)
Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl)
static Value getInitValue(OpBuilder &builder, Location loc, SparseTensorType stt)
A wrapper around RankedTensorType, which has three goals:
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
SparseTensorEncodingAttr getEncoding() const
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:334
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:38
Value genTuple(OpBuilder &builder, Location loc, Type tp, ValueRange values)
Packs the given values as a "tuple" value.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
SparseTensorFieldKind
===-------------------------------------------------------------------—===// The sparse tensor storag...
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238