1 //===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // A pass that converts sparse tensor types and primitives to actual compiler
10 // visible buffers and actual compiler IR that implements these primitives on
11 // the selected sparse tensor storage schemes. This pass provides an alternative
12 // to the SparseTensorConversion pass, eliminating the dependence on a runtime
13 // support library (other than for file I/O), and providing many more
14 // opportunities for subsequent compiler optimization of the generated code.
15 //
16 //===----------------------------------------------------------------------===//
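//
// For example, a CSR-style tensor such as
//
//   tensor<8x8xf64, #sparse_tensor.encoding<{
//     map = (i, j) -> (i : dense, j : compressed) }>>
//
// is, roughly, lowered by this pass into a flat tuple of compiler-visible
// buffers plus a size specifier:
//
//   %positions   : memref<?xindex>   // level-1 positions
//   %coordinates : memref<?xindex>   // level-1 coordinates
//   %values      : memref<?xf64>     // stored values
//   %specifier   : !sparse_tensor.storage_specifier<...>
//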
17 
18 #include "Utils/CodegenUtils.h"
20 
32 
33 #include <optional>
34 
35 using namespace mlir;
36 using namespace mlir::sparse_tensor;
37 
38 //===----------------------------------------------------------------------===//
39 // Helper methods.
40 //===----------------------------------------------------------------------===//
41 
42 /// Flatten the given value ranges into a single vector of values.
43 static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
44  SmallVector<Value> result;
45  for (const auto &vals : values)
46  llvm::append_range(result, vals);
47  return result;
48 }
49 
50 /// Generates a load with proper `index` typing.
51 static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
52  idx = genCast(builder, loc, idx, builder.getIndexType());
53  return memref::LoadOp::create(builder, loc, mem, idx);
54 }
55 
56 /// Generates a store with proper `index` typing and proper value.
57 static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
58  Value idx) {
59  idx = genCast(builder, loc, idx, builder.getIndexType());
60  val = genCast(builder, loc, val,
61  cast<ShapedType>(mem.getType()).getElementType());
62  memref::StoreOp::create(builder, loc, val, mem, idx);
63 }
64 
65 /// Creates a straightforward counting for-loop.
66 static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
67  MutableArrayRef<Value> fields,
68  Value lower = Value()) {
69  Type indexType = builder.getIndexType();
70  if (!lower)
71  lower = constantZero(builder, loc, indexType);
72  Value one = constantOne(builder, loc, indexType);
73  scf::ForOp forOp =
74  scf::ForOp::create(builder, loc, lower, upper, one, fields);
75  for (unsigned i = 0, e = fields.size(); i < e; i++)
76  fields[i] = forOp.getRegionIterArg(i);
77  builder.setInsertionPointToStart(forOp.getBody());
78  return forOp;
79 }
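 
// As a rough sketch, createFor(builder, loc, %n, fields) with two entries in
// `fields` produces IR of the form
//
//   %r:2 = scf.for %i = %c0 to %n step %c1
//       iter_args(%a = %f0, %b = %f1) -> (...) {
//     ...   // caller-generated body
//   }
//
// after which fields[0]/fields[1] refer to the region iteration arguments
// %a and %b, and the insertion point sits at the start of the loop body.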
80 
81 /// Creates a push back operation.
82 static void createPushback(OpBuilder &builder, Location loc,
83  MutSparseTensorDescriptor &desc,
84  SparseTensorFieldKind kind, std::optional<Level> lvl,
85  Value value, Value repeat = Value()) {
86  Type etp = desc.getMemRefElementType(kind, lvl);
87  Value field = desc.getMemRefField(kind, lvl);
88  StorageSpecifierKind specFieldKind = toSpecifierKind(kind);
89 
90  auto pushBackOp = PushBackOp::create(
91  builder, loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl),
92  field, genCast(builder, loc, value, etp), repeat);
93 
94  desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
95  desc.setSpecifierField(builder, loc, specFieldKind, lvl,
96  pushBackOp.getNewSize());
97 }
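 
// In IR terms this emits, roughly, a single buffer-growing operation
//
//   %newBuf, %newSize = sparse_tensor.push_back %curSize, %buf, %val
//       : index, memref<?xindex>, index
//
// (with an optional repeat count), and then re-registers the possibly
// reallocated buffer and its new size in the descriptor and specifier.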
98 
99 /// Generates code that allocates a sparse storage scheme for given rank.
100 static void allocSchemeForRank(OpBuilder &builder, Location loc,
101  MutSparseTensorDescriptor desc, Level startLvl) {
102  const SparseTensorType stt(desc.getRankedTensorType());
103  Value linear = constantIndex(builder, loc, 1);
104  const Level lvlRank = stt.getLvlRank();
105  for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
106  const auto lt = stt.getLvlType(lvl);
107  if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
108  // Append linear x positions, initialized to zero. Since each compressed
109  // dimension initially already has a single zero entry, this maintains
110  // the desired "linear + 1" length property at all times. For loose
111  // compression, we multiply linear by two in order to append both the
112  // lo/hi positions.
113  Value posZero = constantZero(builder, loc, stt.getPosType());
114  if (isLooseCompressedLT(lt)) {
115  Value two = constantIndex(builder, loc, 2);
116  linear = arith::MulIOp::create(builder, loc, linear, two);
117  }
118  createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
119  /*value=*/posZero, /*repeat=*/linear);
120  return;
121  } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
122  return; // nothing to do
123  }
124  // Keep compounding the size, but nothing needs to be initialized
125  // at this level. We will eventually reach a compressed level or
126  // otherwise the values array for the from-here "all-dense" case.
127  assert(isDenseLT(lt));
128  Value size = desc.getLvlSize(builder, loc, lvl);
129  linear = arith::MulIOp::create(builder, loc, linear, size);
130  }
131  // Reached values array so prepare for an insertion.
132  Value valZero = constantZero(builder, loc, stt.getElementType());
133  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
134  std::nullopt, /*value=*/valZero, /*repeat=*/linear);
135 }
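 
// For a CSR-like tensor with m rows and startLvl = 0, for example, the dense
// level-0 step multiplies `linear` by m, and the compressed level 1 then gets
// m additional zero positions appended, preserving the "linear + 1" length
// property relative to the single zero entry created by createAllocFields.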
136 
137 /// Creates allocation operation.
138 static Value createAllocation(OpBuilder &builder, Location loc,
139  MemRefType memRefType, Value sz,
140  bool enableInit) {
141  Value buffer = memref::AllocOp::create(builder, loc, memRefType, sz);
142  Type elemType = memRefType.getElementType();
143  if (enableInit) {
144  Value fillValue = constantZero(builder, loc, elemType);
145  linalg::FillOp::create(builder, loc, fillValue, buffer);
146  }
147  return buffer;
148 }
149 
150 /// Creates the dim sizes array, filling in from dynamic sizes.
151 static void createDimSizes(OpBuilder &builder, Location loc,
152  SparseTensorType stt, ValueRange dynSizes,
153  /*out*/ SmallVectorImpl<Value> &dimSizesValues) {
154  const Dimension dimRank = stt.getDimRank();
155  dimSizesValues.clear();
156  dimSizesValues.reserve(dimRank);
157  unsigned i = 0;
158  for (const Size sz : stt.getDimShape())
159  dimSizesValues.push_back(ShapedType::isDynamic(sz)
160  ? dynSizes[i++]
161  : constantIndex(builder, loc, sz));
162 }
163 
164 /// Creates allocation for each field in sparse tensor type. Note that
165 /// for all dynamic memrefs in the sparse tensor storage layout, the
166 /// memory size is really the capacity of the "vector", while the actual
167 /// size resides in the sizes array.
168 static void createAllocFields(OpBuilder &builder, Location loc,
169  SparseTensorType stt, bool enableInit,
170  Value sizeHint,
171  SmallVectorImpl<Value> &lvlSizesValues,
172  /*out*/ SmallVectorImpl<Value> &fields) {
173  Level lvlRank = stt.getLvlRank();
174  // Set up some heuristic sizes. We try to set the initial
175  // size based on available information. Otherwise we just
176  // initialize a few elements to start the reallocation chain.
177  // TODO: refine this
178  Value posHeuristic, crdHeuristic, valHeuristic;
179  if (stt.isAllDense()) {
180  valHeuristic = lvlSizesValues[0];
181  for (Level lvl = 1; lvl < lvlRank; lvl++)
182  valHeuristic = arith::MulIOp::create(builder, loc, valHeuristic,
183  lvlSizesValues[lvl]);
184  } else if (sizeHint) {
185  if (stt.getAoSCOOStart() == 0) {
186  posHeuristic = constantIndex(builder, loc, 2);
187  crdHeuristic = arith::MulIOp::create(
188  builder, loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
189  } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
190  posHeuristic = arith::AddIOp::create(builder, loc, sizeHint,
191  constantIndex(builder, loc, 1));
192  crdHeuristic = sizeHint;
193  } else {
194  posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
195  }
196  valHeuristic = sizeHint;
197  } else {
198  posHeuristic = crdHeuristic = valHeuristic =
199  constantIndex(builder, loc, 16);
200  }
201  // Initializes all fields. An initial storage specifier and allocated
202  // positions/coordinates/values memrefs (with heuristic capacity).
203  foreachFieldAndTypeInSparseTensor(
204  stt,
205  [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
206  enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
207  Level /*lvl*/, LevelType /*lt*/) -> bool {
208  assert(fields.size() == fIdx);
209  Value field;
210  switch (fKind) {
211  case SparseTensorFieldKind::StorageSpec:
212  field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
213  break;
214  case SparseTensorFieldKind::PosMemRef:
215  field = createAllocation(builder, loc, cast<MemRefType>(fType),
216  posHeuristic, enableInit);
217  break;
218  case SparseTensorFieldKind::CrdMemRef:
219  field = createAllocation(builder, loc, cast<MemRefType>(fType),
220  crdHeuristic, enableInit);
221  break;
222  case SparseTensorFieldKind::ValMemRef:
223  field = createAllocation(builder, loc, cast<MemRefType>(fType),
224  valHeuristic, enableInit);
225  break;
226  }
227  assert(field);
228  fields.push_back(field);
229  // Returns true to continue the iteration.
230  return true;
231  });
232  // Initialize the storage scheme to an empty tensor. Sets the lvlSizes
233  // and gives all position fields an initial zero entry, so that it is
234  // easier to maintain the "linear + 1" length property.
235  MutSparseTensorDescriptor desc(stt, fields);
236  Value posZero = constantZero(builder, loc, stt.getPosType());
237  for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
238  desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
239  const auto lt = stt.getLvlType(lvl);
240  if (isCompressedLT(lt) || isLooseCompressedLT(lt))
241  createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
242  /*value=*/posZero);
243  }
244  allocSchemeForRank(builder, loc, desc, /*rank=*/0);
245 }
246 
247 /// Helper method that generates block specific to compressed case:
248 ///
249 /// // given: parentPos = posCursor[lvl-1]
250 /// pstart = desc.positions[lvl][parentPos]
251 /// pstop = desc.positions[lvl][parentPos+1]
252 /// plast = pstop - 1
253 /// msz = desc.coordinates[lvl].size()
254 /// if (pstart < pstop) {
255 /// isPresent = (desc.coordinates[lvl][plast] == lvlCoords[lvl])
256 /// } else { // first insertion
257 /// isPresent = false
258 /// desc.positions[lvl][parentPos] = msz
259 /// }
260 /// if (isPresent) { // coordinate is already present
261 /// pnext = plast
262 /// } else {
263 /// desc.coordinates[lvl].push_back(lvlCoords[lvl])
264 /// desc.positions[lvl][parentPos+1] = msz+1
265 /// pnext = msz
266 /// <prepare level lvl+1>
267 /// }
268 /// posCursor[lvl] = pnext
269 static Value genCompressed(OpBuilder &builder, Location loc,
270  MutSparseTensorDescriptor desc, ValueRange lvlCoords,
271  Value /*unused*/, Value parentPos, Level lvl) {
272  const SparseTensorType stt(desc.getRankedTensorType());
273  const Level lvlRank = stt.getLvlRank();
274  assert(lvl < lvlRank && "Level is out of bounds");
275  assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
276  "Level-rank mismatch");
277  SmallVector<Type> types;
278  Type indexType = builder.getIndexType();
279  Type boolType = builder.getIntegerType(1);
280  unsigned crdFidx;
281  unsigned crdStride;
282  std::tie(crdFidx, crdStride) = desc.getCrdMemRefIndexAndStride(lvl);
283  const Value one = constantIndex(builder, loc, 1);
284  const Value pp1 = arith::AddIOp::create(builder, loc, parentPos, one);
285  const Value positionsAtLvl = desc.getPosMemRef(lvl);
286  const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
287  const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
288  const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl);
289  const Value crdStrideC =
290  crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
291  const Value msz =
292  crdStrideC ? arith::DivUIOp::create(builder, loc, crdMsz, crdStrideC)
293  : crdMsz;
294  const Value plast = arith::SubIOp::create(
295  builder, loc, genCast(builder, loc, pstop, indexType), one);
296  // Conditional expression.
297  Value lt = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult,
298  pstart, pstop);
299  types.push_back(boolType);
300  scf::IfOp ifOp1 = scf::IfOp::create(builder, loc, types, lt, /*else*/ true);
301  types.pop_back();
302  builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
303  Value crd = genLoad(
304  builder, loc, desc.getMemRefField(crdFidx),
305  crdStrideC ? arith::MulIOp::create(builder, loc, plast, crdStrideC)
306  : plast);
307  Value eq = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
308  genCast(builder, loc, crd, indexType),
309  lvlCoords[lvl]);
310  scf::YieldOp::create(builder, loc, eq);
311  builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
312  if (lvl > 0)
313  genStore(builder, loc, msz, positionsAtLvl, parentPos);
314  scf::YieldOp::create(builder, loc, constantI1(builder, loc, false));
315  builder.setInsertionPointAfter(ifOp1);
316  // If present construct. Note that for a non-unique dimension level, we
317  // simply set the condition to false and rely on CSE/DCE to clean up the IR.
318  //
319  // TODO: generate less temporary IR?
320  //
321  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
322  types.push_back(desc.getField(i).getType());
323  types.push_back(indexType);
324  const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
325  : constantI1(builder, loc, false);
326  scf::IfOp ifOp2 = scf::IfOp::create(builder, loc, types, p, /*else*/ true);
327  // If present (fields unaffected, update pnext to plast).
328  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
329 
330  // FIXME: This does not look like a clean way, but probably the most
331  // efficient way.
332  desc.getFields().push_back(plast);
333  scf::YieldOp::create(builder, loc, desc.getFields());
334  desc.getFields().pop_back();
335 
336  // If !present (changes fields, update pnext).
337  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
338  Value mszp1 = arith::AddIOp::create(builder, loc, msz, one);
339  genStore(builder, loc, mszp1, positionsAtLvl, pp1);
340  createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
341  /*value=*/lvlCoords[lvl]);
342  // Prepare the next level "as needed".
343  if ((lvl + 1) < lvlRank)
344  allocSchemeForRank(builder, loc, desc, lvl + 1);
345 
346  desc.getFields().push_back(msz);
347  scf::YieldOp::create(builder, loc, desc.getFields());
348  desc.getFields().pop_back();
349 
350  // Update fields and return next pos.
351  builder.setInsertionPointAfter(ifOp2);
352  unsigned o = 0;
353  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
354  desc.setField(i, ifOp2.getResult(o++));
355  return ifOp2.getResult(o);
356 }
357 
358 /// Generates insertion finalization code.
359 static void genEndInsert(OpBuilder &builder, Location loc,
360  SparseTensorDescriptor desc) {
361  const SparseTensorType stt(desc.getRankedTensorType());
362  const Level lvlRank = stt.getLvlRank();
363  for (Level lvl = 0; lvl < lvlRank; lvl++) {
364  const auto lt = stt.getLvlType(lvl);
365  if (isCompressedLT(lt)) {
366  // Compressed dimensions need a position cleanup for all entries
367  // that were not visited during the insertion pass.
368  //
369  // TODO: avoid cleanup and keep compressed scheme consistent at all
370  // times?
371  //
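 // As a small example, a level positions array left as [0, 2, 0, 0, 5]
 // by the insertion pass becomes [0, 2, 2, 2, 5] here, since every
 // still-zero entry inherits the last written position value.
 //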
372  if (lvl > 0) {
373  Type posType = stt.getPosType();
374  Value posMemRef = desc.getPosMemRef(lvl);
375  Value hi = desc.getPosMemSize(builder, loc, lvl);
376  Value zero = constantIndex(builder, loc, 0);
377  Value one = constantIndex(builder, loc, 1);
378  // Vector of only one, but needed by createFor's prototype.
379  SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
380  scf::ForOp loop = createFor(builder, loc, hi, inits, one);
381  Value i = loop.getInductionVar();
382  Value oldv = loop.getRegionIterArg(0);
383  Value newv = genLoad(builder, loc, posMemRef, i);
384  Value posZero = constantZero(builder, loc, posType);
385  Value cond = arith::CmpIOp::create(
386  builder, loc, arith::CmpIPredicate::eq, newv, posZero);
387  scf::IfOp ifOp = scf::IfOp::create(builder, loc, TypeRange(posType),
388  cond, /*else*/ true);
389  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
390  genStore(builder, loc, oldv, posMemRef, i);
391  scf::YieldOp::create(builder, loc, oldv);
392  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
393  scf::YieldOp::create(builder, loc, newv);
394  builder.setInsertionPointAfter(ifOp);
395  scf::YieldOp::create(builder, loc, ifOp.getResult(0));
396  builder.setInsertionPointAfter(loop);
397  }
398  } else {
399  assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
400  isNOutOfMLT(lt));
401  }
402  }
403 }
404 
405 /// Generates a subview into the sizes.
406 static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
407  Value sz) {
408  auto memTp = llvm::cast<MemRefType>(mem.getType());
409  // For higher-dimensional memrefs, we assume that the innermost
410  // dimension is always of the right size.
411  // TODO: generate complex truncating view here too?
412  if (memTp.getRank() > 1)
413  return mem;
414  // Truncate linear memrefs to given size.
415  return memref::SubViewOp::create(
416  builder, loc,
417  MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
418  mem, ValueRange{}, ValueRange{sz}, ValueRange{},
419  ArrayRef<int64_t>{0}, // static offset
420  ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
421  ArrayRef<int64_t>{1}) // static stride
422  .getResult();
423 }
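 
// For a 1-D memref this amounts to, roughly,
//
//   %view = memref.subview %mem[0] [%sz] [1] : memref<?xT> to memref<?xT>
//
// so that consumers observe the in-use size rather than the full capacity.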
424 
425 /// Creates the reassociation array.
426 static SmallVector<ReassociationIndices>
427 getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
428  SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
429  // Create reassociation in the form:
430  // {0}, {1}, ..., {batchLvl - 1}, {batchLvl, ..., rank}
431  for (unsigned i = 0; i < batchLvls; i++)
432  ret[i].push_back(i);
433 
434  for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
435  ret.back().push_back(i);
436 
437  return ret;
438 }
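 
// For example, a rank-3 source with one batch level yields {{0}, {1, 2}}:
// the batch dimension stays, and the two trailing dimensions collapse.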
439 
440 //===----------------------------------------------------------------------===//
441 // Codegen rules.
442 //===----------------------------------------------------------------------===//
443 
444 namespace {
445 
446 /// Helper class to help lowering sparse_tensor.insert operation.
447 class SparseInsertGenerator
448  : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
449 public:
450  SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
451  bool genCall)
452  : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp) {};
453 
454  /// Generates code along an insertion path without the need for a "cursor".
455  /// This current insertion strategy comes at the expense of some testing
456  /// overhead for each insertion. The strategy will be optimized later for
457  /// common insertion patterns. The current insertion strategy also assumes
458  /// insertions occur in "a reasonable order" that enables building the
459  /// storage scheme in an appending/inserting kind of fashion (i.e. no
460  /// in-between insertions that need data movement). The implementation
461  /// relies on CSE/DCE to clean up all bookkeeping that is not needed.
462  ///
463  /// TODO: better unord/not-unique; also generalize, optimize, specialize!
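 /// As a sketch for a CSR-like tensor: the dense level simply folds its
 /// coordinate into parentPos, the compressed level runs genCompressed to
 /// append or locate the coordinate, and the value is finally appended to
 /// (or, for a trailing dense level, stored into) the values array.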
464  SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
465  OpBuilder &builder, Location loc) {
466  const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
467  const Level lvlRank = stt.getLvlRank();
468  // Extract fields and coordinates from args.
469  SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
470  MutSparseTensorDescriptor desc(stt, fields);
471  const SmallVector<Value> coords =
472  llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
473  Value value = args.back();
474  Value parentPos = constantZero(builder, loc, builder.getIndexType());
475  // Generate code for every level.
476  for (Level lvl = 0; lvl < lvlRank; lvl++) {
477  const auto lt = stt.getLvlType(lvl);
478  if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
479  // Create:
480  // if (!present) {
481  // coordinates[lvl].push_back(coords[lvl])
482  // <update positions and prepare level lvl + 1>
483  // }
484  // positions[lvl] = coordinates.size() - 1
485  // <insert @ positions[lvl] at next level lvl + 1>
486  if (isLooseCompressedLT(lt)) {
487  Value two = constantIndex(builder, loc, 2);
488  parentPos = arith::MulIOp::create(builder, loc, parentPos, two);
489  }
490  parentPos =
491  genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
492  } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
493  // Create:
494  // coordinates[lvl].push_back(coords[lvl])
495  // positions[lvl] = positions[lvl-1]
496  // <insert @ positions[lvl] at next level lvl + 1>
497  createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
498  lvl, /*value=*/coords[lvl]);
499  } else {
500  assert(isDenseLT(lt));
501  // Construct the new position as:
502  // positions[lvl] = size * positions[lvl-1] + coords[lvl]
503  // <insert @ positions[lvl] at next level lvl + 1>
504  Value size = desc.getLvlSize(builder, loc, lvl);
505  Value mult = arith::MulIOp::create(builder, loc, size, parentPos);
506  parentPos = arith::AddIOp::create(builder, loc, mult, coords[lvl]);
507  }
508  }
509  // Reached the actual value append/insert.
510  if (!stt.isDenseLvl(lvlRank - 1))
511  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
512  std::nullopt, value);
513  else
514  genStore(builder, loc, value, desc.getValMemRef(), parentPos);
515  return fields;
516  }
517 
518  std::string getMangledFuncName() {
519  // The mangled name of the function has this format:
520  // <namePrefix>_<LT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
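 // e.g., roughly "_insert_dense_compressed_8_8_f64_0_0" for an 8x8 CSR
 // tensor of f64 with the default (index) coordinate/position widths.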
521  constexpr const char kInsertFuncNamePrefix[] = "_insert_";
522  const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
523  SmallString<32> nameBuffer;
524  llvm::raw_svector_ostream nameOstream(nameBuffer);
525  nameOstream << kInsertFuncNamePrefix;
526  const Level lvlRank = stt.getLvlRank();
527  for (Level l = 0; l < lvlRank; l++) {
528  std::string lvlType = toMLIRString(stt.getLvlType(l));
529  // Replace/remove punctuations in level properties.
530  std::replace_if(
531  lvlType.begin(), lvlType.end(),
532  [](char c) { return c == '(' || c == ','; }, '_');
533  llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
534  nameOstream << lvlType << "_";
535  }
536  // Static dim sizes are used in the generated code while dynamic sizes are
537  // loaded from the dimSizes buffer. This is the reason for adding the shape
538  // to the function name.
539  for (const auto sz : stt.getDimShape())
540  nameOstream << sz << "_";
541  // Permutation information is also used in generating insertion.
542  if (!stt.isIdentity())
543  nameOstream << stt.getDimToLvl() << "_";
544  nameOstream << stt.getElementType() << "_";
545  nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
546  return nameOstream.str().str();
547  }
548 
549 private:
550  TensorType rtp;
551 };
552 
553 /// Sparse tensor storage conversion rule for returns.
554 class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
555 public:
556  using OpConversionPattern::OpConversionPattern;
557  LogicalResult
558  matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
559  ConversionPatternRewriter &rewriter) const override {
560  // Create a return with the flattened value extracted from sparse tensors.
561  rewriter.replaceOpWithNewOp<func::ReturnOp>(
562  op, flattenValues(adaptor.getOperands()));
563  return success();
564  }
565 };
566 
567 /// Sparse tensor storage conversion rule for calls.
568 class SparseCallConverter : public OpConversionPattern<func::CallOp> {
569 public:
570  // The default CallOp converter can not handle 1:N type conversion.
571  using OpConversionPattern::OpConversionPattern;
572  LogicalResult
573  matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
574  ConversionPatternRewriter &rewriter) const override {
575  Location loc = op.getLoc();
576  // In case of:
577  // sparse_tensor, f, sparse_tensor = call @foo(...)
578  // ==>
579  // memref..., f, memref = call @foo(...) replace with
580  // cast(memref...)->sparse_tensor, f, cast(memref...)->sparse_tensor
581  SmallVector<Type> finalRetTy;
582  if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
583  return failure();
584 
585  // (1) Generates new call with flattened return value.
586  auto newCall =
587  func::CallOp::create(rewriter, loc, op.getCallee(), finalRetTy,
588  flattenValues(adaptor.getOperands()));
589  // (2) Gather sparse tensor returns.
590  SmallVector<SmallVector<Value>> packedResultVals;
591  // Tracks the offset of current return value (of the original call)
592  // relative to the new call (after sparse tensor flattening);
593  unsigned retOffset = 0;
594  // Temporary buffer to hold the flattened list of types for
595  // a sparse tensor.
596  SmallVector<Type> sparseFlat;
597  for (auto ret : op.getResults()) {
598  assert(retOffset < newCall.getNumResults());
599  auto retType = ret.getType();
600  if (failed(typeConverter->convertType(retType, sparseFlat)))
601  llvm_unreachable("Failed to convert type in sparse tensor codegen");
602 
603  // Converted types cannot be empty when the type conversion succeeds.
604  assert(!sparseFlat.empty());
605  if (sparseFlat.size() > 1) {
606  auto flatSize = sparseFlat.size();
607  packedResultVals.emplace_back();
608  llvm::append_range(packedResultVals.back(),
609  newCall.getResults().slice(retOffset, flatSize));
610  retOffset += flatSize;
611  } else {
612  // If this is a 1:1 conversion, no need for casting.
613  packedResultVals.emplace_back();
614  packedResultVals.back().push_back(newCall.getResult(retOffset));
615  retOffset++;
616  }
617  sparseFlat.clear();
618  }
619 
620  assert(packedResultVals.size() == op.getNumResults());
621  rewriter.replaceOpWithMultiple(op, std::move(packedResultVals));
622  return success();
623  }
624 };
625 
626 /// Sparse codegen rule for level accesses.
627 class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
628 public:
629  using OpConversionPattern::OpConversionPattern;
630  LogicalResult
631  matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
632  ConversionPatternRewriter &rewriter) const override {
633  std::optional<int64_t> lvl = op.getConstantLvlIndex();
634  RankedTensorType srcType = op.getSource().getType();
635  if (!lvl || !getSparseTensorEncoding(srcType))
636  return failure();
637 
638  auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
639  auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
640 
641  rewriter.replaceOp(op, sz);
642  return success();
643  }
644 };
645 
646 // TODO: use a new SortCOO operation here instead of reusing convert op.
647 struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
648  using OpConversionPattern::OpConversionPattern;
649  LogicalResult
650  matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
651  ConversionPatternRewriter &rewriter) const override {
652  Location loc = op.getLoc();
653  MLIRContext *ctx = op.getContext();
654 
655  SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
656  SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());
657 
658  // Should have been verified.
659  assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
660  dstStt.isCOOType() && srcStt.isCOOType());
661  assert(dstStt.hasSameDimToLvl(srcStt));
662 
663  // We don't need a mutable descriptor here as we perform sorting in-place.
664  auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
665  op.getInputCoo().getType());
666  auto nnz = desc.getValMemSize(rewriter, op.getLoc());
667  auto crd = desc.getAOSMemRef();
668  auto val = desc.getValMemRef();
669 
670  // Otherwise we need another data shuffle and a non-identity map.
671  assert(dstStt.hasSameDimToLvl(srcStt));
672  (void)dstStt; // to silence warning when assertion is disabled
673 
674  auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
675 
676  SortOp::create(rewriter, loc, nnz, crd, ValueRange{val}, id,
677  rewriter.getIndexAttr(0), op.getAlgorithm());
678 
679  // Since we do in-place sorting, the destination tensor will have the same set
680  // of memrefs as the source tensor.
681  rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()});
682  return success();
683  }
684 };
685 
686 template <typename Op, StorageSpecifierKind kind>
687 class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
688 public:
689  using OpConversionPattern<Op>::OpConversionPattern;
690  using typename OpConversionPattern<Op>::OneToNOpAdaptor;
691 
692  LogicalResult
693  matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
694  ConversionPatternRewriter &rewriter) const override {
695  // Simply lowers to a specifier.get <field> operation.
696  auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
697  op.getSlice().getType());
698  auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
699  op.getDim().getZExtValue());
700 
701  rewriter.replaceOp(op, v);
702  return success();
703  }
704 };
705 
706 /// Sparse codegen rule for trivial tensor casts.
707 class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
708 public:
709  using OpConversionPattern::OpConversionPattern;
710  LogicalResult
711  matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
712  ConversionPatternRewriter &rewriter) const override {
713  // Only rewrite identically annotated source/dest.
714  auto encDst = getSparseTensorEncoding(op.getType());
715  auto encSrc = getSparseTensorEncoding(op.getSource().getType());
716  if (!encDst || encDst != encSrc)
717  return failure();
718  rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
719  return success();
720  }
721 };
722 
723 class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
724 public:
725  using OpConversionPattern::OpConversionPattern;
726  LogicalResult
727  matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
728  ConversionPatternRewriter &rewriter) const override {
729  // Simply fold the operation.
730  rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
731  return success();
732  }
733 };
734 
735 /// Sparse codegen rule for the alloc operator.
736 class SparseTensorAllocConverter
737  : public OpConversionPattern<bufferization::AllocTensorOp> {
738 public:
739  using OpConversionPattern::OpConversionPattern;
740  SparseTensorAllocConverter(const TypeConverter &typeConverter,
741  MLIRContext *context, bool enableInit)
742  : OpConversionPattern(typeConverter, context),
743  enableBufferInitialization(enableInit) {}
744 
745  LogicalResult
746  matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
747  ConversionPatternRewriter &rewriter) const override {
748  const auto resType = getSparseTensorType(op);
749  if (!resType.hasEncoding())
750  return failure();
751 
752  Location loc = op.getLoc();
753  // Deal with copy.
754  if (op.getCopy()) {
755  auto desc = getDescriptorFromTensorTuple(
756  adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
757  SmallVector<Value> fields;
758  fields.reserve(desc.getNumFields());
759  // Memcpy on memref fields.
760  for (auto field : desc.getMemRefFields()) {
761  auto memrefTp = cast<MemRefType>(field.getType());
762  auto size = memref::DimOp::create(rewriter, loc, field, 0);
763  auto copied =
764  memref::AllocOp::create(rewriter, loc, memrefTp, ValueRange{size});
765  memref::CopyOp::create(rewriter, loc, field, copied);
766  fields.push_back(copied);
767  }
768  // Reuses specifier.
769  fields.push_back(desc.getSpecifier());
770  assert(fields.size() == desc.getNumFields());
771  rewriter.replaceOpWithMultiple(op, {fields});
772  return success();
773  }
774 
775  if (!resType.isIdentity()) {
776  return rewriter.notifyMatchFailure(
777  op, "try run --sparse-reinterpret-map before codegen");
778  }
779  // Level size equals dimension size since the lvl2dim map is an identity map.
780  SmallVector<Value> lvlSizesValues;
781  createDimSizes(rewriter, loc, resType,
782  flattenValues(adaptor.getDynamicSizes()),
783  /*dimSizesValues=*/lvlSizesValues);
784 
785  // Construct allocation for each field.
786  Value sizeHint = op.getSizeHint();
787  SmallVector<Value> fields;
788  createAllocFields(rewriter, loc, resType, enableBufferInitialization,
789  sizeHint, lvlSizesValues, fields);
790 
791  // Replace operation with resulting memrefs.
792  rewriter.replaceOpWithMultiple(op, {fields});
793  return success();
794  }
795 
796 private:
797  bool enableBufferInitialization;
798 };
799 
800 /// Sparse codegen rule for the empty tensor operator.
801 class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
802 public:
803  using OpConversionPattern::OpConversionPattern;
804  SparseTensorEmptyConverter(const TypeConverter &typeConverter,
805  MLIRContext *context, bool enableInit)
806  : OpConversionPattern(typeConverter, context),
807  enableBufferInitialization(enableInit) {}
808 
809  LogicalResult
810  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
811  ConversionPatternRewriter &rewriter) const override {
812  const auto resType = getSparseTensorType(op);
813  if (!resType.hasEncoding())
814  return failure();
815 
816  if (!resType.isIdentity()) {
817  return rewriter.notifyMatchFailure(
818  op, "try run --sparse-reinterpret-map before codegen");
819  }
820 
821  Location loc = op.getLoc();
822  // Level size equals dimension size since the lvl2dim map is an identity map.
823  SmallVector<Value> lvlSizesValues;
824  createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
825  /*dimSizesValues=*/lvlSizesValues);
826  // Construct allocation for each field.
827  Value sizeHint; // none
828  SmallVector<Value> fields;
829  createAllocFields(rewriter, loc, resType, enableBufferInitialization,
830  sizeHint, lvlSizesValues, fields);
831 
832  // Replace operation with resulting memrefs.
833  rewriter.replaceOpWithMultiple(op, {fields});
834  return success();
835  }
836 
837 private:
838  bool enableBufferInitialization;
839 };
840 
841 /// Sparse codegen rule for the dealloc operator.
842 class SparseTensorDeallocConverter
843  : public OpConversionPattern<bufferization::DeallocTensorOp> {
844 public:
845  using OpConversionPattern::OpConversionPattern;
846  SparseTensorDeallocConverter(const TypeConverter &typeConverter,
847  MLIRContext *context, bool createDeallocs)
848  : OpConversionPattern(typeConverter, context),
849  createDeallocs(createDeallocs) {}
850 
851  LogicalResult
852  matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
853  ConversionPatternRewriter &rewriter) const override {
854  auto enc = getSparseTensorEncoding(op.getTensor().getType());
855  if (!enc)
856  return failure();
857 
858  // If user requests not to deallocate sparse tensors, simply erase the
859  // operation.
860  if (createDeallocs) {
861  // Replace the sparse tensor deallocation with field deallocations.
862  Location loc = op.getLoc();
863  auto desc = getDescriptorFromTensorTuple(
864  adaptor.getTensor(),
865  cast<RankedTensorType>(op.getTensor().getType()));
866  for (auto input : desc.getMemRefFields())
867  // Deallocate every buffer used to store the sparse tensor handler.
868  memref::DeallocOp::create(rewriter, loc, input);
869  }
870  rewriter.eraseOp(op);
871  return success();
872  }
873 
874 private:
875  const bool createDeallocs;
876 };
877 
878 /// Sparse codegen rule for tensor rematerialization.
879 class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
880 public:
881  using OpConversionPattern::OpConversionPattern;
882  LogicalResult
883  matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
884  ConversionPatternRewriter &rewriter) const override {
885  // Prepare descriptor.
886  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
887  op.getTensor().getType());
888  // Generate optional insertion finalization code.
889  if (op.getHasInserts())
890  genEndInsert(rewriter, op.getLoc(), desc);
891  // Replace operation with resulting memrefs.
892  rewriter.replaceOpWithMultiple(op, {desc.getFields()});
893  return success();
894  }
895 };
896 
897 /// Sparse codegen rule for the expand op.
898 class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
899 public:
900  using OpConversionPattern::OpConversionPattern;
901  LogicalResult
902  matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
903  ConversionPatternRewriter &rewriter) const override {
904  if (!getSparseTensorEncoding(op.getTensor().getType()))
905  return failure();
906  Location loc = op->getLoc();
907  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
908  op.getTensor().getType());
909  const auto srcType = getSparseTensorType(op.getTensor());
910  Type eltType = srcType.getElementType();
911  Type boolType = rewriter.getIntegerType(1);
912  Type idxType = rewriter.getIndexType();
913  // All initialization should be done on entry of the loop nest.
914  rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
915 
916  // Determine the size for access expansion (always the innermost stored
917  // level size).
918  const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
919  // Generate a memref for `sz` elements of type `t`.
920  const auto genAlloc = [&](Type t) {
921  const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
922  return memref::AllocOp::create(rewriter, loc, memTp, ValueRange{sz});
923  };
924  // Allocate temporary buffers for values/filled-switch and added.
925  // We do not use stack buffers for this, since the expanded size may
926  // be rather large (as it envelops a single expanded dense dimension).
927  Value values = genAlloc(eltType);
928  Value filled = genAlloc(boolType);
929  Value added = genAlloc(idxType);
930  Value zero = constantZero(rewriter, loc, idxType);
931  // Reset the values/filled-switch to all-zero/false. Note that this
932  // introduces an O(N) operation into the computation, but this reset
933  // operation is amortized over the innermost loops for the access
934  // pattern expansion. As noted in the operation doc, we would like
935  // to amortize this setup cost even between kernels.
936  linalg::FillOp::create(rewriter, loc,
937  ValueRange{constantZero(rewriter, loc, eltType)},
938  ValueRange{values});
939  linalg::FillOp::create(rewriter, loc,
940  ValueRange{constantZero(rewriter, loc, boolType)},
941  ValueRange{filled});
942  // Replace expansion op with these buffers and initial coordinate.
943  assert(op.getNumResults() == 4);
944  rewriter.replaceOp(op, {values, filled, added, zero});
945  return success();
946  }
947 };
948 
949 /// Sparse codegen rule for the compress operator.
950 class SparseCompressConverter : public OpConversionPattern<CompressOp> {
951 public:
952  using OpConversionPattern::OpConversionPattern;
953  LogicalResult
954  matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
955  ConversionPatternRewriter &rewriter) const override {
956  Location loc = op->getLoc();
957  SmallVector<Value> fields;
958  auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
959  op.getTensor().getType());
960  Value values = llvm::getSingleElement(adaptor.getValues());
961  Value filled = llvm::getSingleElement(adaptor.getFilled());
962  Value added = llvm::getSingleElement(adaptor.getAdded());
963  Value count = llvm::getSingleElement(adaptor.getCount());
964  const SparseTensorType dstType(desc.getRankedTensorType());
965  Type eltType = dstType.getElementType();
966 
967  // If the innermost level is ordered, we need to sort the coordinates
968  // in the "added" array prior to applying the compression.
969  if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
970  SortOp::create(rewriter, loc, count, added, ValueRange{},
971  rewriter.getMultiDimIdentityMap(1),
972  rewriter.getIndexAttr(0),
973  SparseTensorSortKind::HybridQuickSort);
974  // While performing the insertions, we also need to reset the elements
975  // of the values/filled-switch by only iterating over the set elements,
976  // to ensure that the runtime complexity remains proportional to the
977  // sparsity of the expanded access pattern.
978  //
979  // Generate
980  // out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
981  // crd = added[i];
982  // value = values[crd];
983  // insert({lvlCoords, crd}, value);
984  // new_memrefs = insert(in_memrefs, {lvlCoords, crd}, value);
985  // values[crd] = 0;
986  // filled[crd] = false;
987  // yield new_memrefs
988  // }
989  scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
990  Value i = loop.getInductionVar();
991 
992  Value crd = genLoad(rewriter, loc, added, i);
993  Value value = genLoad(rewriter, loc, values, crd);
994  SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
995  SmallVector<Type> flatSpTensorTps = llvm::to_vector(
996  llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
997  SmallVector<Value> flatLvlCoords = flattenValues(adaptor.getLvlCoords());
998  params.append(flatLvlCoords.begin(), flatLvlCoords.end());
999  params.push_back(crd);
1000  params.push_back(value);
1001  SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
1002  params, /*genCall=*/true);
1003  SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
1004  genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
1005  genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
1006  scf::YieldOp::create(rewriter, loc, insertRet);
1007 
1008  rewriter.setInsertionPointAfter(loop);
1009  // Deallocate the buffers on exit of the full loop nest.
1010  Operation *parent = getTop(op);
1011  rewriter.setInsertionPointAfter(parent);
1012  memref::DeallocOp::create(rewriter, loc, values);
1013  memref::DeallocOp::create(rewriter, loc, filled);
1014  memref::DeallocOp::create(rewriter, loc, added);
1015  // Replace operation with resulting memrefs.
1016  rewriter.replaceOpWithMultiple(op, {loop->getResults()});
1017  return success();
1018  }
1019 };
1020 
1021 /// Sparse codegen rule for the insert operator.
1022 class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
1023 public:
1024  using OpConversionPattern::OpConversionPattern;
1025  LogicalResult
1026  matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
1027  ConversionPatternRewriter &rewriter) const override {
1028  auto stt = getSparseTensorType(op.getDest());
1029  if (!stt.hasEncoding())
1030  return failure();
1031  assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
1032 
1033  Location loc = op.getLoc();
1034  auto desc =
1035  getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
1036  TypeRange flatSpTensorTps = desc.getFields().getTypes();
1037  SmallVector<Value> params = llvm::to_vector(desc.getFields());
1038  SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
1039  params.append(flatIndices.begin(), flatIndices.end());
1040  params.push_back(llvm::getSingleElement(adaptor.getScalar()));
1041  SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
1042  params, /*genCall=*/true);
1043  SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
1044  // Replace operation with resulting memrefs.
1045  rewriter.replaceOpWithMultiple(op, {ret});
1046  return success();
1047  }
1048 };
1049 
1050 /// Sparse codegen rule for position accesses.
1051 class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
1052 public:
1053  using OpAdaptor = typename ToPositionsOp::Adaptor;
1054  using OpConversionPattern::OpConversionPattern;
1055  LogicalResult
1056  matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
1057  ConversionPatternRewriter &rewriter) const override {
1058  // Replace the requested position access with corresponding field.
1059  // The view is restricted to the actual size to ensure clients
1060  // of this operation truly observe size, not capacity!
1061  Location loc = op.getLoc();
1062  Level lvl = op.getLevel();
1063  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1064  op.getTensor().getType());
1065  auto mem = desc.getPosMemRef(lvl);
1066  auto size = desc.getPosMemSize(rewriter, loc, lvl);
1067  rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1068  return success();
1069  }
1070 };
1071 
1072 /// Sparse codegen rule for accessing the coordinates arrays.
1073 class SparseToCoordinatesConverter
1074  : public OpConversionPattern<ToCoordinatesOp> {
1075 public:
1076  using OpAdaptor = typename ToCoordinatesOp::Adaptor;
1077  using OpConversionPattern::OpConversionPattern;
1078  LogicalResult
1079  matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
1080  ConversionPatternRewriter &rewriter) const override {
1081  // Replace the requested coordinates access with corresponding field.
1082  // The view is restricted to the actual size to ensure clients
1083  // of this operation truly observe size, not capacity!
1084  Location loc = op.getLoc();
1085  Level lvl = op.getLevel();
1086  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1087  op.getTensor().getType());
1088  auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
1089  if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
1090  auto size = desc.getCrdMemSize(rewriter, loc, lvl);
1091  mem = genSliceToSize(rewriter, loc, mem, size);
1092  }
1093  rewriter.replaceOp(op, mem);
1094  return success();
1095  }
1096 };
1097 
1098 /// Sparse codegen rule for accessing the linear coordinates buffer.
1099 class SparseToCoordinatesBufferConverter
1100  : public OpConversionPattern<ToCoordinatesBufferOp> {
1101 public:
1102  using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
1103  using OpConversionPattern::OpConversionPattern;
1104  LogicalResult
1105  matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
1106  ConversionPatternRewriter &rewriter) const override {
1107  // Replace the requested coordinates access with corresponding field.
1108  // The view is restricted to the actual size to ensure clients
1109  // of this operation truly observe size, not capacity!
1110  Location loc = op.getLoc();
1111  Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
1112  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1113  op.getTensor().getType());
1114  auto mem = desc.getAOSMemRef();
1115  auto size = desc.getCrdMemSize(rewriter, loc, lvl);
1116  rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1117  return success();
1118  }
1119 };
1120 
1121 /// Sparse codegen rule for value accesses.
1122 class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
1123 public:
1124  using OpAdaptor = typename ToValuesOp::Adaptor;
1125  using OpConversionPattern::OpConversionPattern;
1126  LogicalResult
1127  matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
1128  ConversionPatternRewriter &rewriter) const override {
1129  // Replace the requested values access with corresponding field.
1130  // The view is restricted to the actual size to ensure clients
1131  // of this operation truly observe size, not capacity!
1132  Location loc = op.getLoc();
1133  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1134  op.getTensor().getType());
1135  auto mem = desc.getValMemRef();
1136  auto size = desc.getValMemSize(rewriter, loc);
1137  rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1138  return success();
1139  }
1140 };
1141 
1142 /// Sparse codegen rule for the convert operator.
1143 class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
1144 public:
1145  using OpConversionPattern::OpConversionPattern;
1146  LogicalResult
1147  matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
1148  ConversionPatternRewriter &rewriter) const override {
1149  SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
1150  SparseTensorEncodingAttr encSrc =
1151  getSparseTensorEncoding(op.getSource().getType());
1152  // The output tensor can not be a slice and those cases should have been
1153  // rejected by ConvertOp::verify() already.
1154  assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slice.");
1155  // Different encoding (except for different bitwidth) should be handled by
1156  // rewriting.
1157  // We need further rewrites if the input tensor is a slice too.
1158  if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
1159  encSrc.isSlice()) {
1160  return failure();
1161  }
1162 
1163  Type retElemTp = op.getResult().getType().getElementType();
1164  Type srcElemTp = op.getSource().getType().getElementType();
1165  // Fold the trivial cases.
1166  if (retElemTp == srcElemTp && encDst == encSrc) {
1167  rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
1168  return success();
1169  }
1170  //
1171  // Do element-wise type conversion without using InsertOp.
1172  //
1173  // for each memref in srcTensor:
1174  // dst = memref.alloc
1175  // if srcMemRefType != dstMemRefType:
1176  // for every dst[i] = cast(src[i])
1177  // else:
1178  // dst = memref.copy(src)
1179  Location loc = op.getLoc();
1180  auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
1181  op.getSource().getType());
1182  SmallVector<Value> fields;
1183  foreachFieldAndTypeInSparseTensor(
1184  SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
1185  [&rewriter, &fields, srcDesc,
1186  loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
1187  LevelType /*lt*/) -> bool {
1188  // Simply reuses the storage specifier as it is an SSA value.
1189  if (fKind == SparseTensorFieldKind::StorageSpec) {
1190  fields.push_back(srcDesc.getSpecifier());
1191  } else {
1192  // Allocates new memrefs
1193  Value srcMem = srcDesc.getMemRefField(fIdx);
1194  // TODO: We can instead use the actual memSize in specifier, that
1195  // would require a subViewOp to avoid overflow when copying
1196  // values.
1197  Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
1198  auto dstMem = memref::AllocOp::create(rewriter, loc,
1199  cast<MemRefType>(fTp), sz);
1200  if (fTp != srcMem.getType()) {
1201  // Converts elements type.
1202  scf::buildLoopNest(
1203  rewriter, loc, constantIndex(rewriter, loc, 0), sz,
1204  constantIndex(rewriter, loc, 1),
1205  [srcMem, &dstMem](OpBuilder &builder, Location loc,
1206  ValueRange ivs) {
1207  Value v = memref::LoadOp::create(builder, loc, srcMem, ivs);
1208  Value casted = genCast(builder, loc, v,
1209  dstMem.getType().getElementType());
1210  memref::StoreOp::create(builder, loc, casted, dstMem, ivs);
1211  });
1212  } else {
1213  // TODO: We can even reuse the same memref for the new tensor,
1214  // but that requires a `ref-counting` based memory management
1215  // for shared memrefs between multiple sparse tensors.
1216  memref::CopyOp::create(rewriter, loc, srcMem, dstMem);
1217  }
1218  fields.push_back(dstMem);
1219  }
1220  return true;
1221  });
1222 
1223  rewriter.replaceOpWithMultiple(op, {fields});
1224  return success();
1225  }
1226 };
1227 
1228 class SparseExtractSliceConverter
1229  : public OpConversionPattern<tensor::ExtractSliceOp> {
1230 public:
1231  using OpConversionPattern::OpConversionPattern;
1232  LogicalResult
1233  matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
1234  ConversionPatternRewriter &rewriter) const override {
1235  Location loc = op.getLoc();
1236  MLIRContext *ctx = op.getContext();
1237  auto srcEnc = getSparseTensorEncoding(op.getSourceType());
1238  auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
1239  // TODO: We should check these in ExtractSliceOp::verify.
1240  if (!srcEnc || !dstEnc || !dstEnc.isSlice())
1241  return failure();
1242  assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
1243 
1244  SmallVector<Value> fields;
1245  auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
1246  op.getSource().getType());
1247 
1248  auto newSpec = StorageSpecifierInitOp::create(
1249  rewriter, loc, StorageSpecifierType::get(ctx, dstEnc),
1250  desc.getSpecifier());
1251  desc.setSpecifier(newSpec);
1252 
1253  // Fills in slice information.
1254  for (auto [idx, offset, size, stride] : llvm::enumerate(
1255  op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
1256  Dimension dim = idx;
1257 
1258  Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
1259  Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);
1260  Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride);
1261  // TODO: We could probably only set the dynamic values here. But it would
1262  // require us to fill the hole when casting a static slice to a dynamic
1263  // slice.
1264  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset,
1265  dim, offsetV);
1266 
1267  // FIXME: we need to distinguish level sizes and dimension size for slices
1268  // here. Maybe we should store slice level sizes in a different array
1269  // instead of reusing it.
1270  assert(srcEnc.isIdentity());
1271  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
1272  sizeV);
1273  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
1274  dim, strideV);
1275  }
1276 
1277  // NOTE: we cannot generate tuples directly from the descriptor here, as the
1278  // descriptor is holding the original type, yet we want the slice type
1279  // here (they share every memref but with an updated specifier).
1280  rewriter.replaceOpWithMultiple(op, {desc.getFields()});
1281  return success();
1282  }
1283 };
1284 
1285 /// Sparse codegen rule for number of entries operator.
1286 class SparseNumberOfEntriesConverter
1287  : public OpConversionPattern<NumberOfEntriesOp> {
1288 public:
1289  using OpConversionPattern::OpConversionPattern;
1290  LogicalResult
1291  matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
1292  ConversionPatternRewriter &rewriter) const override {
1293  // Query memSizes for the actually stored values.
1294  // FIXME: the nse value computed in this way might be wrong when there is
1295  // any "loose_compressed" level.
1296  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1297  op.getTensor().getType());
1298  rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
1299  return success();
1300  }
1301 };
1302 
1303 struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
1304  using OpConversionPattern::OpConversionPattern;
1305  LogicalResult
1306  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
1307  ConversionPatternRewriter &rewriter) const override {
1308  Location loc = op.getLoc();
1309  const auto stt = getSparseTensorType(op.getResult());
1310 
1311  SmallVector<Value> fields;
1312 
1313  foreachFieldAndTypeInSparseTensor(
1314  stt,
1315  [&rewriter, &fields, &op, &stt,
1316  loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
1317  Level /*lvl*/, LevelType lt) -> bool {
1318  assert(fields.size() == fIdx);
1319  if (fKind == SparseTensorFieldKind::StorageSpec) {
1320  fields.push_back(
1321  SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
1322  } else {
1323  // Else simply takes the inputs.
1324  Value tensor = fKind == SparseTensorFieldKind::ValMemRef
1325  ? op.getValues()
1326  : op.getLevels()[fIdx];
1327  // TODO: handle batch.
1328  TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
1329  if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
1330  // Flattens the buffer to batchLvlRank.
1331  auto reassoc = getReassociationForFlattening(
1332  mem.getType(), stt.getBatchLvlRank());
1333  mem = memref::CastOp::create(
1334  rewriter, loc, fType,
1335  memref::CollapseShapeOp::create(rewriter, loc, mem, reassoc));
1336  } else {
1337  mem = memref::CastOp::create(rewriter, loc, fType, mem);
1338  }
1339  fields.push_back(mem);
1340  }
1341  return true;
1342  });
1343 
1344  MutSparseTensorDescriptor desc(stt, fields);
1345  Value c0 = constantIndex(rewriter, loc, 0);
1346  Value c1 = constantIndex(rewriter, loc, 1);
1347  Value c2 = constantIndex(rewriter, loc, 2);
1348  Value posBack = c0; // index to the last value in the position array
1349  Value memSize = c1; // memory size for current array
1350 
1351  Level trailCOOStart = stt.getAoSCOOStart();
1352  Level trailCOORank = stt.getLvlRank() - trailCOOStart;
1353  // Sets up SparseTensorSpecifier.
1354  for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
1355  assert(ShapedType::isStatic(stt.getDimShape()[lvl]));
1356 
1357  // Sets up the level size.
1358  auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
1359  desc.setLvlSize(rewriter, loc, lvl, lvlSize);
1360  // We use a single AOS array to store the trailing COO, so there is only
1361  // one memory size to set for the entire COO section.
1362  if (lvl > trailCOOStart)
1363  continue;
1364 
1365  // Sets up the memory size by reading the last value in position array.
1366  LevelType lt = stt.getLvlType(lvl);
1367  // Simply forwards the position index when this is a dense level.
1368  if (lt.isa<LevelFormat::Dense>()) {
1369  memSize = arith::MulIOp::create(rewriter, loc, lvlSize, memSize);
1370  posBack = arith::SubIOp::create(rewriter, loc, memSize, c1);
1371  continue;
1372  }
1373  if (lt.isa<LevelFormat::Batch>()) {
1374  // Skips batch levels as they are not linearized.
1375  // FIXME: this assumes that every batch has the same number of nse, need
1376  // to be generalized to handle varied-size batches.
1377  continue;
1378  }
1379 
1380  if (isWithPosLT(lt)) {
1381  assert(isCompressedLT(lt) || isLooseCompressedLT(lt));
1382  if (isLooseCompressedLT(lt)) {
1383  memSize = arith::MulIOp::create(rewriter, loc, memSize, c2);
1384  posBack = arith::SubIOp::create(rewriter, loc, memSize, c1);
1385  } else {
1386  assert(isCompressedLT(lt));
1387  posBack = memSize;
1388  memSize = arith::AddIOp::create(rewriter, loc, memSize, c1);
1389  }
1390  desc.setPosMemSize(rewriter, loc, lvl, memSize);
1391  // The last value in the position array is the memory size for the next level.
1392  // FIXME: this assumes that every batch has the same number of stored
1393  // entries (nse); it needs to be generalized to handle varied-size batches.
1394  SmallVector<Value> batched(stt.getBatchLvlRank(),
1395  constantIndex(rewriter, loc, 0));
1396  batched.push_back(posBack);
1397  memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
1398  posBack = arith::SubIOp::create(rewriter, loc, posBack, c1);
1399  }
1400  assert(isWithCrdLT(lt) && lvl <= trailCOOStart);
1401  // FIXME: This seems to be unnecessarily complex, can we simplify it?
1402  if (lvl == trailCOOStart) {
1403  Value cooSz = arith::MulIOp::create(
1404  rewriter, loc, memSize, constantIndex(rewriter, loc, trailCOORank));
1405  desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
1406  } else {
1407  desc.setCrdMemSize(rewriter, loc, lvl, memSize);
1408  }
1409  }
1410  desc.setValMemSize(rewriter, loc, memSize);
1411 
1412  rewriter.replaceOpWithMultiple(op, {desc.getFields()});
1413  return success();
1414  }
1415 };
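// Illustrative trace of the specifier setup above (a sketch only, assuming a
// hypothetical 10x10 CSR input, i.e. levels (dense, compressed), assembled
// from buffers holding `nse` stored entries):
//   lvl 0 (dense):      memSize = 10 * 1 = 10, posBack = 9
//   lvl 1 (compressed): posBack = 10, memSize = 10 + 1 = 11, PosMemSize = 11,
//                       memSize = positions[10] = nse, CrdMemSize = nse
//   after the loop:     ValMemSize = nse
// For a type with a trailing AoS COO region, the coordinate memory size at the
// COO start level is memSize * trailCOORank instead, since a single AoS buffer
// serves the whole COO section.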
1416 
1417 struct SparseDisassembleOpConverter
1418  : public OpConversionPattern<DisassembleOp> {
1420  SparseDisassembleOpConverter(const TypeConverter &typeConverter,
1421  MLIRContext *context)
1422  : OpConversionPattern(typeConverter, context) {}
1423 
1424  LogicalResult
1425  matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
1426  ConversionPatternRewriter &rewriter) const override {
1427  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1428  op.getTensor().getType());
1429  Location loc = op.getLoc();
1430  SmallVector<Value> retMem;
1431  SmallVector<Value> retLen;
1432  desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem,
1433  &retLen](FieldIndex fid,
1434  SparseTensorFieldKind fKind,
1435  Level lvl, LevelType lt) -> bool {
1436  if (fKind == SparseTensorFieldKind::StorageSpec)
1437  return true;
1438  SparseTensorType stt(desc.getRankedTensorType());
1439  Value sz, src;
1440  TypedValue<BaseMemRefType> dst;
1441  if (fKind == SparseTensorFieldKind::ValMemRef) {
1442  sz = desc.getValMemSize(rewriter, loc);
1443  src = desc.getValMemRef();
1444  dst = genToMemref(rewriter, loc, op.getOutValues());
1445 
1446  retMem.push_back(dst);
1447  Type valLenTp = op.getValLen().getType();
1448  retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
1449  } else {
1450  assert(fKind == SparseTensorFieldKind::PosMemRef ||
1451  fKind == SparseTensorFieldKind::CrdMemRef);
1452 
1453  sz = fKind == SparseTensorFieldKind::PosMemRef
1454  ? desc.getPosMemSize(rewriter, loc, lvl)
1455  : desc.getCrdMemSize(rewriter, loc, lvl);
1456  src = desc.getMemRefField(fid);
1457  dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
1458  retMem.push_back(dst);
1459  // Retrieves the corresponding level length type.
1460  Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
1461  retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
1462  }
1463  Value flatOut = dst;
1464  if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
1465  auto reassoc =
1466  getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());
1467  flatOut = memref::CollapseShapeOp::create(rewriter, loc, dst, reassoc);
1468  }
1469  Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
1470  Value srcMem = genSliceToSize(rewriter, loc, src, sz);
1471  memref::CopyOp::create(rewriter, loc, srcMem, dstMem);
1472  return true;
1473  });
1474 
1475  // Converts MemRefs back to Tensors.
1476  SmallVector<Value> retValues = llvm::to_vector(
1477  llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
1478  return bufferization::ToTensorOp::create(
1479  rewriter, loc, memref::getTensorTypeFromMemRefType(v.getType()),
1480  v);
1481  }));
1482  // Appends the actual memory length used in each buffer returned.
1483  retValues.append(retLen.begin(), retLen.end());
1484  rewriter.replaceOp(op, retValues);
1485  return success();
1486  }
1487 };
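// Illustrative result of the disassembling above (a sketch only, for the same
// hypothetical 10x10 CSR tensor with `nse` stored entries): the positions
// buffer (11 entries), coordinates buffer (nse entries), and values buffer
// (nse entries) are copied into size-`sz` slices of the user-supplied
// out_levels/out_values buffers, which are then returned as tensors followed
// by the used lengths, i.e. lvl_lens = (11, nse) and val_len = nse.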
1488 
1489 struct SparseNewConverter : public OpConversionPattern<NewOp> {
1490  using OpConversionPattern::OpConversionPattern;
1491  LogicalResult
1492  matchAndRewrite(NewOp op, OpAdaptor adaptor,
1493  ConversionPatternRewriter &rewriter) const override {
1494  Location loc = op.getLoc();
1495  const auto dstTp = getSparseTensorType(op.getResult());
1496  // Creating COO with NewOp is handled by direct IR codegen. All other cases
1497  // are handled by rewriting.
1498  if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
1499  return failure();
1500 
1501  // Implement as follows:
1502  // %reader = @createCheckedSparseTensorReader(%filename)
1503  // %nse = @getSparseTensorNSE(%reader)
1504  // %coo = bufferization.alloc_tensor an ordered COO with
1505  // dst dim ordering, size_hint = %nse
1506  // %coordinates = sparse_tensor.coordinates_buffer(%coo)
1507  // %values = sparse_tensor.values(%coo)
1508  // %isSorted = @sparseTensorReaderReadToBuffers(%coordinates, %values)
1509  // if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
1510  // update storage specifier
1511  // @delSparseTensorReader(%reader)
1512  SmallVector<Value> dimSizesValues;
1513  Value dimSizesBuffer;
1514  Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
1515  dimSizesValues, dimSizesBuffer);
1516 
1517  // Get the number of stored entries.
1518  const Type indexTp = rewriter.getIndexType();
1519  Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
1520  {indexTp}, {reader}, EmitCInterface::Off)
1521  .getResult(0);
1522 
1523  // Construct the lvl sizes and the dim2lvl/lvl2dim buffers.
1524  SmallVector<Value> lvlSizesValues;
1525  Value dim2lvlBuffer;
1526  Value lvl2dimBuffer;
1527  genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
1528  lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
1529 
1530  // Construct allocation for each field.
1531  Value sizeHint = nse;
1532  SmallVector<Value> fields;
1533  createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint,
1534  lvlSizesValues, fields);
1535 
1536  // Read the COO tensor data.
1537  MutSparseTensorDescriptor desc(dstTp, fields);
1538  Value xs = desc.getAOSMemRef();
1539  Value ys = desc.getValMemRef();
1540  const Type boolTp = rewriter.getIntegerType(1);
1541  const Type elemTp = dstTp.getElementType();
1542  const Type crdTp = dstTp.getCrdType();
1543  SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
1544  overheadTypeFunctionSuffix(crdTp),
1545  primaryTypeFunctionSuffix(elemTp)};
1546  Value isSorted =
1547  createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
1548  {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
1549  EmitCInterface::On)
1550  .getResult(0);
1551 
1552  // If the destination tensor is a sorted COO, we need to sort the COO tensor
1553  // data if the input elements aren't sorted yet.
1554  const Level lvlRank = dstTp.getLvlRank();
1555  if (dstTp.isOrderedLvl(lvlRank - 1)) {
1556  Value kFalse = constantI1(rewriter, loc, false);
1557  Value notSorted = arith::CmpIOp::create(
1558  rewriter, loc, arith::CmpIPredicate::eq, isSorted, kFalse);
1559  scf::IfOp ifOp =
1560  scf::IfOp::create(rewriter, loc, notSorted, /*else*/ false);
1561  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1562  auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
1563  SortOp::create(rewriter, loc, nse, xs, ValueRange{ys}, xPerm,
1564  rewriter.getIndexAttr(0),
1565  SparseTensorSortKind::HybridQuickSort);
1566  rewriter.setInsertionPointAfter(ifOp);
1567  }
1568 
1569  // Set PosMemRef0[1] = nse.
1570  const Value c1 = constantIndex(rewriter, loc, 1);
1571  const Value posMemref0 = desc.getPosMemRef(0);
1572  const Type posTp = dstTp.getPosType();
1573  const Value posNse = genCast(rewriter, loc, nse, posTp);
1574  memref::StoreOp::create(rewriter, loc, posNse, posMemref0, c1);
1575 
1576  // Update storage specifier.
1577  Value coordinatesSize = arith::MulIOp::create(
1578  rewriter, loc, nse, constantIndex(rewriter, loc, lvlRank));
1579  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0,
1580  coordinatesSize);
1581  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize,
1582  std::nullopt, nse);
1583 
1584  // Release the sparse tensor reader.
1585  createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
1586  EmitCInterface::Off);
1587 
1588  // Replace operation with resulting memrefs.
1589  rewriter.replaceOpWithMultiple(op, {fields});
1590  return success();
1591  }
1592 };
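// Illustrative shape of the storage produced above (a sketch only, assuming a
// 2-D sorted AoS COO destination, i.e. levels (compressed(nonunique),
// singleton), read from a file with `nse` entries): the level-0 position
// buffer gets `nse` stored at index 1, the single AoS coordinate buffer holds
// nse * 2 coordinates, the value buffer holds nse values, and the specifier
// records CrdMemSize = nse * 2 and ValMemSize = nse.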
1593 
1594 struct SparseHasRuntimeLibraryConverter
1595  : public OpConversionPattern<HasRuntimeLibraryOp> {
1596  using OpConversionPattern::OpConversionPattern;
1597  LogicalResult
1598  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
1599  ConversionPatternRewriter &rewriter) const override {
1600  auto i1Type = rewriter.getI1Type();
1601  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
1602  op, i1Type, rewriter.getIntegerAttr(i1Type, 0));
1603  return success();
1604  }
1605 };
1606 
1607 } // namespace
1608 
1609 //===----------------------------------------------------------------------===//
1610 // Public method for populating conversion rules.
1611 //===----------------------------------------------------------------------===//
1612 
1613 /// Populates the given patterns list with conversion rules required for
1614 /// the sparsification of linear algebra operations.
1615 void mlir::populateSparseTensorCodegenPatterns(
1616  const TypeConverter &typeConverter, RewritePatternSet &patterns,
1617  bool createSparseDeallocs, bool enableBufferInitialization) {
1618  patterns.add<
1619  SparseAssembleOpConverter, SparseDisassembleOpConverter,
1620  SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
1621  SparseCastConverter, SparseExtractSliceConverter,
1622  SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter,
1623  SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter,
1624  SparseSliceGetterOpConverter<ToSliceOffsetOp,
1625  StorageSpecifierKind::DimOffset>,
1626  SparseSliceGetterOpConverter<ToSliceStrideOp,
1627  StorageSpecifierKind::DimStride>,
1628  SparseToPositionsConverter, SparseToCoordinatesConverter,
1629  SparseToCoordinatesBufferConverter, SparseToValuesConverter,
1630  SparseConvertConverter, SparseNewConverter,
1631  SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
1632  typeConverter, patterns.getContext());
1633  patterns.add<SparseTensorDeallocConverter>(
1634  typeConverter, patterns.getContext(), createSparseDeallocs);
1635  patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
1636  typeConverter, patterns.getContext(), enableBufferInitialization);
1637 }