MLIR  19.0.0git
SparseStorageSpecifierToLLVM.cpp
Go to the documentation of this file.
1 //===- SparseStorageSpecifierToLLVM.cpp - convert specifier to llvm -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Utils/CodegenUtils.h"
10 
14 
15 #include <optional>
16 
17 using namespace mlir;
18 using namespace sparse_tensor;
19 
20 namespace {
21 
22 //===----------------------------------------------------------------------===//
23 // Helper methods.
24 //===----------------------------------------------------------------------===//
25 
26 static SmallVector<Type, 4> getSpecifierFields(StorageSpecifierType tp) {
27  MLIRContext *ctx = tp.getContext();
28  auto enc = tp.getEncoding();
29  const Level lvlRank = enc.getLvlRank();
30 
31  SmallVector<Type, 4> result;
32  // TODO: how can we get the lowering type for index type in the later pipeline
33  // to be consistent? LLVM::StructureType does not allow index fields.
34  auto sizeType = IntegerType::get(tp.getContext(), 64);
35  auto lvlSizes = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
36  auto memSizes = LLVM::LLVMArrayType::get(ctx, sizeType,
38  result.push_back(lvlSizes);
39  result.push_back(memSizes);
40 
41  if (enc.isSlice()) {
42  // Extra fields are required for the slice information.
43  auto dimOffset = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
44  auto dimStride = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
45 
46  result.push_back(dimOffset);
47  result.push_back(dimStride);
48  }
49 
50  return result;
51 }
52 
53 static Type convertSpecifier(StorageSpecifierType tp) {
54  return LLVM::LLVMStructType::getLiteral(tp.getContext(),
55  getSpecifierFields(tp));
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // Specifier struct builder.
60 //===----------------------------------------------------------------------===//
61 
62 constexpr uint64_t kLvlSizePosInSpecifier = 0;
63 constexpr uint64_t kMemSizePosInSpecifier = 1;
64 constexpr uint64_t kDimOffsetPosInSpecifier = 2;
65 constexpr uint64_t kDimStridePosInSpecifier = 3;
66 
67 class SpecifierStructBuilder : public StructBuilder {
68 private:
69  Value extractField(OpBuilder &builder, Location loc,
70  ArrayRef<int64_t> indices) const {
71  return genCast(builder, loc,
72  builder.create<LLVM::ExtractValueOp>(loc, value, indices),
73  builder.getIndexType());
74  }
75 
76  void insertField(OpBuilder &builder, Location loc, ArrayRef<int64_t> indices,
77  Value v) {
78  value = builder.create<LLVM::InsertValueOp>(
79  loc, value, genCast(builder, loc, v, builder.getIntegerType(64)),
80  indices);
81  }
82 
83 public:
84  explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) {
85  assert(value);
86  }
87 
88  // Undef value for dimension sizes, all zero value for memory sizes.
89  static Value getInitValue(OpBuilder &builder, Location loc, Type structType,
90  Value source);
91 
92  Value lvlSize(OpBuilder &builder, Location loc, Level lvl) const;
93  void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value size);
94 
95  Value dimOffset(OpBuilder &builder, Location loc, Dimension dim) const;
96  void setDimOffset(OpBuilder &builder, Location loc, Dimension dim,
97  Value size);
98 
99  Value dimStride(OpBuilder &builder, Location loc, Dimension dim) const;
100  void setDimStride(OpBuilder &builder, Location loc, Dimension dim,
101  Value size);
102 
103  Value memSize(OpBuilder &builder, Location loc, FieldIndex fidx) const;
104  void setMemSize(OpBuilder &builder, Location loc, FieldIndex fidx,
105  Value size);
106 
107  Value memSizeArray(OpBuilder &builder, Location loc) const;
108  void setMemSizeArray(OpBuilder &builder, Location loc, Value array);
109 };
110 
111 Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
112  Type structType, Value source) {
113  Value metaData = builder.create<LLVM::UndefOp>(loc, structType);
114  SpecifierStructBuilder md(metaData);
115  if (!source) {
116  auto memSizeArrayType =
117  cast<LLVM::LLVMArrayType>(cast<LLVM::LLVMStructType>(structType)
118  .getBody()[kMemSizePosInSpecifier]);
119 
120  Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
121  // Fill memSizes array with zero.
122  for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
123  md.setMemSize(builder, loc, i, zero);
124  } else {
125  // We copy non-slice information (memory sizes array) from source
126  SpecifierStructBuilder sourceMd(source);
127  md.setMemSizeArray(builder, loc, sourceMd.memSizeArray(builder, loc));
128  }
129  return md;
130 }
131 
132 /// Builds IR extracting the pos-th offset from the descriptor.
133 Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc,
134  Dimension dim) const {
135  return extractField(
136  builder, loc,
137  ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)});
138 }
139 
140 /// Builds IR inserting the pos-th offset into the descriptor.
141 void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc,
142  Dimension dim, Value size) {
143  insertField(
144  builder, loc,
145  ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)},
146  size);
147 }
148 
149 /// Builds IR extracting the `lvl`-th level-size from the descriptor.
150 Value SpecifierStructBuilder::lvlSize(OpBuilder &builder, Location loc,
151  Level lvl) const {
152  // This static_cast makes the narrowing of `lvl` explicit, as required
153  // by the braces notation for the ctor.
154  return extractField(
155  builder, loc,
156  ArrayRef<int64_t>{kLvlSizePosInSpecifier, static_cast<int64_t>(lvl)});
157 }
158 
159 /// Builds IR inserting the `lvl`-th level-size into the descriptor.
160 void SpecifierStructBuilder::setLvlSize(OpBuilder &builder, Location loc,
161  Level lvl, Value size) {
162  // This static_cast makes the narrowing of `lvl` explicit, as required
163  // by the braces notation for the ctor.
164  insertField(
165  builder, loc,
166  ArrayRef<int64_t>{kLvlSizePosInSpecifier, static_cast<int64_t>(lvl)},
167  size);
168 }
169 
170 /// Builds IR extracting the pos-th stride from the descriptor.
171 Value SpecifierStructBuilder::dimStride(OpBuilder &builder, Location loc,
172  Dimension dim) const {
173  return extractField(
174  builder, loc,
175  ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)});
176 }
177 
178 /// Builds IR inserting the pos-th stride into the descriptor.
179 void SpecifierStructBuilder::setDimStride(OpBuilder &builder, Location loc,
180  Dimension dim, Value size) {
181  insertField(
182  builder, loc,
183  ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)},
184  size);
185 }
186 
187 /// Builds IR extracting the pos-th memory size into the descriptor.
188 Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
189  FieldIndex fidx) const {
190  return extractField(
191  builder, loc,
192  ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)});
193 }
194 
195 /// Builds IR inserting the `fidx`-th memory-size into the descriptor.
196 void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
197  FieldIndex fidx, Value size) {
198  insertField(
199  builder, loc,
200  ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)},
201  size);
202 }
203 
204 /// Builds IR extracting the memory size array from the descriptor.
205 Value SpecifierStructBuilder::memSizeArray(OpBuilder &builder,
206  Location loc) const {
207  return builder.create<LLVM::ExtractValueOp>(loc, value,
208  kMemSizePosInSpecifier);
209 }
210 
211 /// Builds IR inserting the memory size array into the descriptor.
212 void SpecifierStructBuilder::setMemSizeArray(OpBuilder &builder, Location loc,
213  Value array) {
214  value = builder.create<LLVM::InsertValueOp>(loc, value, array,
215  kMemSizePosInSpecifier);
216 }
217 
218 } // namespace
219 
220 //===----------------------------------------------------------------------===//
221 // The sparse storage specifier type converter (defined in Passes.h).
222 //===----------------------------------------------------------------------===//
223 
225  addConversion([](Type type) { return type; });
226  addConversion(convertSpecifier);
227 }
228 
229 //===----------------------------------------------------------------------===//
230 // Storage specifier conversion rules.
231 //===----------------------------------------------------------------------===//
232 
233 template <typename Base, typename SourceOp>
235 public:
236  using OpAdaptor = typename SourceOp::Adaptor;
238 
240  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
241  ConversionPatternRewriter &rewriter) const override {
242  SpecifierStructBuilder spec(adaptor.getSpecifier());
243  switch (op.getSpecifierKind()) {
244  case StorageSpecifierKind::LvlSize: {
245  Value v = Base::onLvlSize(rewriter, op, spec, (*op.getLevel()));
246  rewriter.replaceOp(op, v);
247  return success();
248  }
249  case StorageSpecifierKind::DimOffset: {
250  Value v = Base::onDimOffset(rewriter, op, spec, (*op.getLevel()));
251  rewriter.replaceOp(op, v);
252  return success();
253  }
254  case StorageSpecifierKind::DimStride: {
255  Value v = Base::onDimStride(rewriter, op, spec, (*op.getLevel()));
256  rewriter.replaceOp(op, v);
257  return success();
258  }
259  case StorageSpecifierKind::CrdMemSize:
260  case StorageSpecifierKind::PosMemSize:
261  case StorageSpecifierKind::ValMemSize: {
262  auto enc = op.getSpecifier().getType().getEncoding();
263  StorageLayout layout(enc);
264  std::optional<unsigned> lvl;
265  if (op.getLevel())
266  lvl = (*op.getLevel());
267  unsigned idx =
268  layout.getMemRefFieldIndex(toFieldKind(op.getSpecifierKind()), lvl);
269  Value v = Base::onMemSize(rewriter, op, spec, idx);
270  rewriter.replaceOp(op, v);
271  return success();
272  }
273  }
274  llvm_unreachable("unrecognized specifer kind");
275  }
276 };
277 
279  : public SpecifierGetterSetterOpConverter<StorageSpecifierSetOpConverter,
280  SetStorageSpecifierOp> {
281  using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
282 
283  static Value onLvlSize(OpBuilder &builder, SetStorageSpecifierOp op,
284  SpecifierStructBuilder &spec, Level lvl) {
285  spec.setLvlSize(builder, op.getLoc(), lvl, op.getValue());
286  return spec;
287  }
288 
289  static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op,
290  SpecifierStructBuilder &spec, Dimension d) {
291  spec.setDimOffset(builder, op.getLoc(), d, op.getValue());
292  return spec;
293  }
294 
295  static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op,
296  SpecifierStructBuilder &spec, Dimension d) {
297  spec.setDimStride(builder, op.getLoc(), d, op.getValue());
298  return spec;
299  }
300 
301  static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op,
302  SpecifierStructBuilder &spec, FieldIndex fidx) {
303  spec.setMemSize(builder, op.getLoc(), fidx, op.getValue());
304  return spec;
305  }
306 };
307 
309  : public SpecifierGetterSetterOpConverter<StorageSpecifierGetOpConverter,
310  GetStorageSpecifierOp> {
311  using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
312 
313  static Value onLvlSize(OpBuilder &builder, GetStorageSpecifierOp op,
314  SpecifierStructBuilder &spec, Level lvl) {
315  return spec.lvlSize(builder, op.getLoc(), lvl);
316  }
317 
318  static Value onDimOffset(OpBuilder &builder, GetStorageSpecifierOp op,
319  const SpecifierStructBuilder &spec, Dimension d) {
320  return spec.dimOffset(builder, op.getLoc(), d);
321  }
322 
323  static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op,
324  const SpecifierStructBuilder &spec, Dimension d) {
325  return spec.dimStride(builder, op.getLoc(), d);
326  }
327 
328  static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op,
329  SpecifierStructBuilder &spec, FieldIndex fidx) {
330  return spec.memSize(builder, op.getLoc(), fidx);
331  }
332 };
333 
335  : public OpConversionPattern<StorageSpecifierInitOp> {
336 public:
339  matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor,
340  ConversionPatternRewriter &rewriter) const override {
341  Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
342  rewriter.replaceOp(
343  op, SpecifierStructBuilder::getInitValue(
344  rewriter, op.getLoc(), llvmType, adaptor.getSource()));
345  return success();
346  }
347 };
348 
349 //===----------------------------------------------------------------------===//
350 // Public method for populating conversion rules.
351 //===----------------------------------------------------------------------===//
352 
354  RewritePatternSet &patterns) {
357  patterns.getContext());
358 }
LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
IndexType getIndexType()
Definition: Builders.cpp:71
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:453
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
Helper class to produce LLVM dialect operations extracting or inserting values to a struct.
Definition: StructBuilder.h:26
Type conversion class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Provides methods to access fields of a sparse tensor with the given encoding.
FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Gets the field index for required field.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Definition: CodegenUtils.h:312
unsigned FieldIndex
The type of field indices.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
Definition: SparseTensor.h:35
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:38
SparseTensorFieldKind toFieldKind(StorageSpecifierKind kind)
unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc)
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy)
Add type casting between arith and index types when needed.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateStorageSpecifierToLLVMPatterns(TypeConverter &converter, RewritePatternSet &patterns)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
static Value onLvlSize(OpBuilder &builder, GetStorageSpecifierOp op, SpecifierStructBuilder &spec, Level lvl)
static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op, const SpecifierStructBuilder &spec, Dimension d)
static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op, SpecifierStructBuilder &spec, FieldIndex fidx)
static Value onDimOffset(OpBuilder &builder, GetStorageSpecifierOp op, const SpecifierStructBuilder &spec, Dimension d)
LogicalResult matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static Value onLvlSize(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, Level lvl)
static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, Dimension d)
static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, Dimension d)
static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, FieldIndex fidx)
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26