1 //===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // A pass that converts sparse tensor types and primitives to actual compiler
10 // visible buffers and actual compiler IR that implements these primitives on
11 // the selected sparse tensor storage schemes. This pass provides an alternative
12 // to the SparseTensorConversion pass, eliminating the dependence on a runtime
13 // support library (other than for file I/O), and providing many more
14 // opportunities for subsequent compiler optimization of the generated code.
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "Utils/CodegenUtils.h"
19 #include "Utils/SparseTensorDescriptor.h"
20 
32 
33 #include <optional>
34 
35 using namespace mlir;
36 using namespace mlir::sparse_tensor;
37 
38 //===----------------------------------------------------------------------===//
39 // Helper methods.
40 //===----------------------------------------------------------------------===//
41 
42 /// Flattens a list of operands that may contain sparse tensors.
43 static void flattenOperands(ValueRange operands,
44  SmallVectorImpl<Value> &flattened) {
45  // In case of
46  // sparse_tensor, c, sparse_tensor
47  // ==>
48  // memref ..., c, memref ...
49  for (auto operand : operands) {
50  if (getSparseTensorEncoding(operand.getType())) {
51  auto tuple = getTuple(operand);
52  // An unrealized_conversion_cast is inserted by the type converter to
53  // bridge the gap of the 1:N conversion between a sparse tensor and its
54  // fields. In this case, take the operands of the cast and replace the
55  // sparse tensor output with the flattened field array.
56  flattened.append(tuple.getOperands().begin(), tuple.getOperands().end());
57  } else {
58  flattened.push_back(operand);
59  }
60  }
61 }
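// Illustrative sketch (not tied to a specific encoding): flattening the
// operands (%t : tensor<?x?xf64, #CSR>, %c : index) would roughly yield
//   %positions, %coordinates, %values, %specifier, %c
// i.e. the sparse tensor is replaced by the memref fields of its storage
// plus its storage specifier, while dense operands pass through unchanged.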
62 
63 /// Generates a load with proper `index` typing.
64 static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
65  idx = genCast(builder, loc, idx, builder.getIndexType());
66  return builder.create<memref::LoadOp>(loc, mem, idx);
67 }
68 
69 /// Generates a store with proper `index` typing and proper value.
70 static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
71  Value idx) {
72  idx = genCast(builder, loc, idx, builder.getIndexType());
73  val = genCast(builder, loc, val,
74  cast<ShapedType>(mem.getType()).getElementType());
75  builder.create<memref::StoreOp>(loc, val, mem, idx);
76 }
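// For example (illustrative only): storing an f64 value at an i32 index into
// a memref<?xf32> first casts the index to `index` and the value to f32, so
// callers never need to match position/coordinate bitwidths by hand.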
77 
78 /// Creates a straightforward counting for-loop.
79 static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
80  MutableArrayRef<Value> fields,
81  Value lower = Value()) {
82  Type indexType = builder.getIndexType();
83  if (!lower)
84  lower = constantZero(builder, loc, indexType);
85  Value one = constantOne(builder, loc, indexType);
86  scf::ForOp forOp = builder.create<scf::ForOp>(loc, lower, upper, one, fields);
87  for (unsigned i = 0, e = fields.size(); i < e; i++)
88  fields[i] = forOp.getRegionIterArg(i);
89  builder.setInsertionPointToStart(forOp.getBody());
90  return forOp;
91 }
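// The emitted loop looks roughly like (types and exact names elided):
//   %out:N = scf.for %i = %lo to %hi step %c1 iter_args(%f = fields) {
//     ... caller-generated body ...
//     scf.yield %updatedFields
//   }
// with `fields` rebound to the region iteration arguments on return.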
92 
93 /// Creates a push back operation.
94 static void createPushback(OpBuilder &builder, Location loc,
95  MutSparseTensorDescriptor desc,
96  SparseTensorFieldKind kind, std::optional<Level> lvl,
97  Value value, Value repeat = Value()) {
98  Type etp = desc.getMemRefElementType(kind, lvl);
99  Value field = desc.getMemRefField(kind, lvl);
100  StorageSpecifierKind specFieldKind = toSpecifierKind(kind);
101 
102  auto pushBackOp = builder.create<PushBackOp>(
103  loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field,
104  genCast(builder, loc, value, etp), repeat);
105 
106  desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
107  desc.setSpecifierField(builder, loc, specFieldKind, lvl,
108  pushBackOp.getNewSize());
109 }
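// Sketch of the effect: the sparse_tensor.push_back appends `value` (or
// `repeat` copies of it) to the chosen memref field, and the descriptor is
// updated with both the possibly reallocated buffer and its new size in the
// storage specifier.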
110 
111 /// Generates code that allocates a sparse storage scheme for given rank.
112 static void allocSchemeForRank(OpBuilder &builder, Location loc,
113  MutSparseTensorDescriptor desc, Level startLvl) {
114  const SparseTensorType stt(desc.getRankedTensorType());
115  Value linear = constantIndex(builder, loc, 1);
116  const Level lvlRank = stt.getLvlRank();
117  for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
118  const auto lt = stt.getLvlType(lvl);
119  if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
120  // Append linear x positions, initialized to zero. Since each compressed
121  // dimension initially already has a single zero entry, this maintains
122  // the desired "linear + 1" length property at all times. For loose
123  // compression, we multiply linear by two in order to append both the
124  // lo/hi positions.
125  Value posZero = constantZero(builder, loc, stt.getPosType());
126  if (isLooseCompressedLT(lt)) {
127  Value two = constantIndex(builder, loc, 2);
128  linear = builder.create<arith::MulIOp>(loc, linear, two);
129  }
130  createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
131  /*value=*/posZero, /*repeat=*/linear);
132  return;
133  } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
134  return; // nothing to do
135  }
136  // Keep compounding the size, but nothing needs to be initialized
137  // at this level. We will eventually reach a compressed level or
138  // otherwise the values array for the from-here "all-dense" case.
139  assert(isDenseLT(lt));
140  Value size = desc.getLvlSize(builder, loc, lvl);
141  linear = builder.create<arith::MulIOp>(loc, linear, size);
142  }
143  // Reached values array so prepare for an insertion.
144  Value valZero = constantZero(builder, loc, stt.getElementType());
145  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
146  std::nullopt, /*value=*/valZero, /*repeat=*/linear);
147 }
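// Example (sketch): for a CSR-like tensor (dense, compressed) with level-0
// size n, starting at level 0 multiplies linear up to n, appends n zero
// entries to positions[1] (preserving the "linear + 1" length property),
// and returns without touching the coordinates or values arrays.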
148 
149 /// Creates allocation operation.
150 static Value createAllocation(OpBuilder &builder, Location loc,
151  MemRefType memRefType, Value sz,
152  bool enableInit) {
153  Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
154  Type elemType = memRefType.getElementType();
155  if (enableInit) {
156  Value fillValue = constantZero(builder, loc, elemType);
157  builder.create<linalg::FillOp>(loc, fillValue, buffer);
158  }
159  return buffer;
160 }
161 
162 /// Creates the dim sizes array, filling in from dynamic sizes.
163 static void createDimSizes(OpBuilder &builder, Location loc,
164  SparseTensorType stt, ValueRange dynSizes,
165  /*out*/ SmallVectorImpl<Value> &dimSizesValues) {
166  const Dimension dimRank = stt.getDimRank();
167  dimSizesValues.clear();
168  dimSizesValues.reserve(dimRank);
169  unsigned i = 0;
170  for (const Size sz : stt.getDimShape())
171  dimSizesValues.push_back(ShapedType::isDynamic(sz)
172  ? dynSizes[i++]
173  : constantIndex(builder, loc, sz));
174 }
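// For instance (illustrative): with stt = tensor<?x100xf64, ...> and
// dynSizes = (%d), the result is [%d, c100]; dynamic dimensions consume the
// next dynamic size operand while static ones become index constants.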
175 
176 /// Creates allocation for each field in sparse tensor type. Note that
177 /// for all dynamic memrefs in the sparse tensor storage layout, the
178 /// memory size is really the capacity of the "vector", while the actual
179 /// size resides in the sizes array.
180 static void createAllocFields(OpBuilder &builder, Location loc,
181  SparseTensorType stt, bool enableInit,
182  Value sizeHint,
183  SmallVectorImpl<Value> &lvlSizesValues,
184  /*out*/ SmallVectorImpl<Value> &fields) {
185  Level lvlRank = stt.getLvlRank();
186  // Set up some heuristic sizes. We try to set the initial
187  // size based on available information. Otherwise we just
188  // initialize a few elements to start the reallocation chain.
189  // TODO: refine this
190  Value posHeuristic, crdHeuristic, valHeuristic;
191  if (stt.isAllDense()) {
192  valHeuristic = lvlSizesValues[0];
193  for (Level lvl = 1; lvl < lvlRank; lvl++)
194  valHeuristic =
195  builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
196  } else if (sizeHint) {
197  if (stt.getAoSCOOStart() == 0) {
198  posHeuristic = constantIndex(builder, loc, 2);
199  crdHeuristic = builder.create<arith::MulIOp>(
200  loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
201  } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
202  posHeuristic = builder.create<arith::AddIOp>(
203  loc, sizeHint, constantIndex(builder, loc, 1));
204  crdHeuristic = sizeHint;
205  } else {
206  posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
207  }
208  valHeuristic = sizeHint;
209  } else {
210  posHeuristic = crdHeuristic = valHeuristic =
211  constantIndex(builder, loc, 16);
212  }
213  // Initializes all fields. An initial storage specifier and allocated
214  // positions/coordinates/values memrefs (with heuristic capacity).
215  foreachFieldAndTypeInSparseTensor(
216  stt,
217  [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
218  enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
219  Level /*lvl*/, LevelType /*lt*/) -> bool {
220  assert(fields.size() == fIdx);
221  Value field;
222  switch (fKind) {
223  case SparseTensorFieldKind::StorageSpec:
224  field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
225  break;
226  case SparseTensorFieldKind::PosMemRef:
227  field = createAllocation(builder, loc, cast<MemRefType>(fType),
228  posHeuristic, enableInit);
229  break;
230  case SparseTensorFieldKind::CrdMemRef:
231  field = createAllocation(builder, loc, cast<MemRefType>(fType),
232  crdHeuristic, enableInit);
233  break;
234  case SparseTensorFieldKind::ValMemRef:
235  field = createAllocation(builder, loc, cast<MemRefType>(fType),
236  valHeuristic, enableInit);
237  break;
238  }
239  assert(field);
240  fields.push_back(field);
241  // Returns true to continue the iteration.
242  return true;
243  });
244  // Initialize the storage scheme to an empty tensor. Sets the lvlSizes
245  // and gives all position fields an initial zero entry, so that it is
246  // easier to maintain the "linear + 1" length property.
247  MutSparseTensorDescriptor desc(stt, fields);
248  Value posZero = constantZero(builder, loc, stt.getPosType());
249  for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
250  desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
251  const auto lt = stt.getLvlType(lvl);
252  if (isCompressedLT(lt) || isLooseCompressedLT(lt))
253  createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
254  /*value=*/posZero);
255  }
256  allocSchemeForRank(builder, loc, desc, /*rank=*/0);
257 }
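// Heuristics sketch (values taken from the code above): for an AoS COO
// format starting at level 0 with a size hint of nnz, the initial capacities
// are roughly 2 for positions, lvlRank * nnz for the linearized coordinates,
// and nnz for values; without any hint, all buffers start at capacity 16.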
258 
259 /// Helper method that generates block specific to compressed case:
260 ///
261 /// // given: parentPos = posCursor[lvl-1]
262 /// pstart = desc.positions[lvl][parentPos]
263 /// pstop = desc.positions[lvl][parentPos+1]
264 /// plast = pstop - 1
265 /// msz = desc.coordinates[lvl].size()
266 /// if (pstart < pstop) {
267 /// isPresent = (desc.coordinates[lvl][plast] == lvlCoords[lvl])
268 /// } else { // first insertion
269 /// isPresent = false
270 /// desc.positions[lvl][parentPos] = msz
271 /// }
272 /// if (isPresent) { // coordinate is already present
273 /// pnext = plast
274 /// } else {
275 /// desc.coordinates[lvl].push_back(lvlCoords[lvl])
276 /// desc.positions[lvl][parentPos+1] = msz+1
277 /// pnext = msz
278 /// <prepare level lvl+1>
279 /// }
280 /// posCursor[lvl] = pnext
281 static Value genCompressed(OpBuilder &builder, Location loc,
282  MutSparseTensorDescriptor desc, ValueRange lvlCoords,
283  Value /*unused*/, Value parentPos, Level lvl) {
284  const SparseTensorType stt(desc.getRankedTensorType());
285  const Level lvlRank = stt.getLvlRank();
286  assert(lvl < lvlRank && "Level is out of bounds");
287  assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
288  "Level-rank mismatch");
289  SmallVector<Type> types;
290  Type indexType = builder.getIndexType();
291  Type boolType = builder.getIntegerType(1);
292  unsigned crdFidx;
293  unsigned crdStride;
294  std::tie(crdFidx, crdStride) = desc.getCrdMemRefIndexAndStride(lvl);
295  const Value one = constantIndex(builder, loc, 1);
296  const Value pp1 = builder.create<arith::AddIOp>(loc, parentPos, one);
297  const Value positionsAtLvl = desc.getPosMemRef(lvl);
298  const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
299  const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
300  const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl);
301  const Value crdStrideC =
302  crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
303  const Value msz =
304  crdStrideC ? builder.create<arith::DivUIOp>(loc, crdMsz, crdStrideC)
305  : crdMsz;
306  const Value plast = builder.create<arith::SubIOp>(
307  loc, genCast(builder, loc, pstop, indexType), one);
308  // Conditional expression.
309  Value lt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
310  pstart, pstop);
311  types.push_back(boolType);
312  scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
313  types.pop_back();
314  builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
315  Value crd =
316  genLoad(builder, loc, desc.getMemRefField(crdFidx),
317  crdStrideC ? builder.create<arith::MulIOp>(loc, plast, crdStrideC)
318  : plast);
319  Value eq = builder.create<arith::CmpIOp>(
320  loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
321  lvlCoords[lvl]);
322  builder.create<scf::YieldOp>(loc, eq);
323  builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
324  if (lvl > 0)
325  genStore(builder, loc, msz, positionsAtLvl, parentPos);
326  builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
327  builder.setInsertionPointAfter(ifOp1);
328  // If present construct. Note that for a non-unique dimension level, we
329  // simply set the condition to false and rely on CSE/DCE to clean up the IR.
330  //
331  // TODO: generate less temporary IR?
332  //
333  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
334  types.push_back(desc.getField(i).getType());
335  types.push_back(indexType);
336  const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
337  : constantI1(builder, loc, false);
338  scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true);
339  // If present (fields unaffected, update pnext to plast).
340  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
341 
342  // FIXME: This does not look like a clean way, but probably the most
343  // efficient way.
344  desc.getFields().push_back(plast);
345  builder.create<scf::YieldOp>(loc, desc.getFields());
346  desc.getFields().pop_back();
347 
348  // If !present (changes fields, update pnext).
349  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
350  Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
351  genStore(builder, loc, mszp1, positionsAtLvl, pp1);
352  createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
353  /*value=*/lvlCoords[lvl]);
354  // Prepare the next level "as needed".
355  if ((lvl + 1) < lvlRank)
356  allocSchemeForRank(builder, loc, desc, lvl + 1);
357 
358  desc.getFields().push_back(msz);
359  builder.create<scf::YieldOp>(loc, desc.getFields());
360  desc.getFields().pop_back();
361 
362  // Update fields and return next pos.
363  builder.setInsertionPointAfter(ifOp2);
364  unsigned o = 0;
365  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
366  desc.setField(i, ifOp2.getResult(o++));
367  return ifOp2.getResult(o);
368 }
369 
370 /// Generates insertion finalization code.
371 static void genEndInsert(OpBuilder &builder, Location loc,
372  SparseTensorDescriptor desc) {
373  const SparseTensorType stt(desc.getRankedTensorType());
374  const Level lvlRank = stt.getLvlRank();
375  for (Level lvl = 0; lvl < lvlRank; lvl++) {
376  const auto lt = stt.getLvlType(lvl);
377  if (isCompressedLT(lt)) {
378  // Compressed dimensions need a position cleanup for all entries
379  // that were not visited during the insertion pass.
380  //
381  // TODO: avoid cleanup and keep compressed scheme consistent at all
382  // times?
383  //
384  if (lvl > 0) {
385  Type posType = stt.getPosType();
386  Value posMemRef = desc.getPosMemRef(lvl);
387  Value hi = desc.getPosMemSize(builder, loc, lvl);
388  Value zero = constantIndex(builder, loc, 0);
389  Value one = constantIndex(builder, loc, 1);
390  // Vector of only one, but needed by createFor's prototype.
391  SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
392  scf::ForOp loop = createFor(builder, loc, hi, inits, one);
393  Value i = loop.getInductionVar();
394  Value oldv = loop.getRegionIterArg(0);
395  Value newv = genLoad(builder, loc, posMemRef, i);
396  Value posZero = constantZero(builder, loc, posType);
397  Value cond = builder.create<arith::CmpIOp>(
398  loc, arith::CmpIPredicate::eq, newv, posZero);
399  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType),
400  cond, /*else*/ true);
401  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
402  genStore(builder, loc, oldv, posMemRef, i);
403  builder.create<scf::YieldOp>(loc, oldv);
404  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
405  builder.create<scf::YieldOp>(loc, newv);
406  builder.setInsertionPointAfter(ifOp);
407  builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
408  builder.setInsertionPointAfter(loop);
409  }
410  } else {
411  assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
412  isNOutOfMLT(lt));
413  }
414  }
415 }
416 
417 /// Generates a subview into the sizes.
418 static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
419  Value sz) {
420  auto memTp = llvm::cast<MemRefType>(mem.getType());
421  // For higher-dimensional memrefs, we assume that the innermost
422  // dimension is always of the right size.
423  // TODO: generate complex truncating view here too?
424  if (memTp.getRank() > 1)
425  return mem;
426  // Truncate linear memrefs to given size.
427  return builder
428  .create<memref::SubViewOp>(
429  loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
430  mem, ValueRange{}, ValueRange{sz}, ValueRange{},
431  ArrayRef<int64_t>{0}, // static offset
432  ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
433  ArrayRef<int64_t>{1}) // static stride
434  .getResult();
435 }
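// Roughly, for a 1-D buffer this emits
//   memref.subview %mem[0] [%sz] [1] : memref<?xT> to memref<?xT>
// so consumers observe the actual size rather than the allocated capacity.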
436 
437 /// Creates the reassociation array.
438 static SmallVector<ReassociationIndices>
439 getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
440  SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
441  // Create reassociation in the form:
442  // {0}, {1}, ..., {batchLvl - 1}, {batchLvl, ..., rank}
443  for (unsigned i = 0; i < batchLvls; i++)
444  ret[i].push_back(i);
445 
446  for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
447  ret.back().push_back(i);
448 
449  return ret;
450 }
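// Example (sketch): with batchLvls = 1 and a rank-3 source type this returns
// {{0}, {1, 2}}, i.e. batch dimensions stay separate while the remaining
// dimensions collapse into one.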
451 
452 //===----------------------------------------------------------------------===//
453 // Codegen rules.
454 //===----------------------------------------------------------------------===//
455 
456 namespace {
457 
458 /// Helper class to help lowering sparse_tensor.insert operation.
459 class SparseInsertGenerator
460  : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
461 public:
462  SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
463  bool genCall)
464  : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp){};
465 
466  /// Generates code along an insertion path without the need for a "cursor".
467  /// This current insertion strategy comes at the expense of some testing
468  /// overhead for each insertion. The strategy will be optimized later for
469  /// common insertion patterns. The current insertion strategy also assumes
470  /// insertions occur in "a reasonable order" that enables building the
471  /// storage scheme in an appending/inserting kind of fashion (i.e. no
472  /// in-between insertions that need data movement). The implementation
473  /// relies on CSE/DCE to clean up all bookkeeping that is not needed.
474  ///
475  /// TODO: better unord/not-unique; also generalize, optimize, specialize!
476  SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
477  OpBuilder &builder, Location loc) {
478  const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
479  const Level lvlRank = stt.getLvlRank();
480  // Extract fields and coordinates from args.
481  SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
482  MutSparseTensorDescriptor desc(stt, fields);
483  const SmallVector<Value> coords =
484  llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
485  Value value = args.back();
486  Value parentPos = constantZero(builder, loc, builder.getIndexType());
487  // Generate code for every level.
488  for (Level lvl = 0; lvl < lvlRank; lvl++) {
489  const auto lt = stt.getLvlType(lvl);
490  if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
491  // Create:
492  // if (!present) {
493  // coordinates[lvl].push_back(coords[lvl])
494  // <update positions and prepare level lvl + 1>
495  // }
496  // positions[lvl] = coordinates.size() - 1
497  // <insert @ positions[lvl] at next level lvl + 1>
498  if (isLooseCompressedLT(lt)) {
499  Value two = constantIndex(builder, loc, 2);
500  parentPos = builder.create<arith::MulIOp>(loc, parentPos, two);
501  }
502  parentPos =
503  genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
504  } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
505  // Create:
506  // coordinates[lvl].push_back(coords[lvl])
507  // positions[lvl] = positions[lvl-1]
508  // <insert @ positions[lvl] at next level lvl + 1>
509  createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
510  lvl, /*value=*/coords[lvl]);
511  } else {
512  assert(isDenseLT(lt));
513  // Construct the new position as:
514  // positions[lvl] = size * positions[lvl-1] + coords[lvl]
515  // <insert @ positions[lvl] at next level lvl + 1>
516  Value size = desc.getLvlSize(builder, loc, lvl);
517  Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
518  parentPos = builder.create<arith::AddIOp>(loc, mult, coords[lvl]);
519  }
520  }
521  // Reached the actual value append/insert.
522  if (!stt.isDenseLvl(lvlRank - 1))
523  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
524  std::nullopt, value);
525  else
526  genStore(builder, loc, value, desc.getValMemRef(), parentPos);
527  return fields;
528  }
529 
530  std::string getMangledFuncName() {
531  // The mangled name of the function has this format:
532  // <namePrefix>_<LT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
533  constexpr const char kInsertFuncNamePrefix[] = "_insert_";
534  const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
535  SmallString<32> nameBuffer;
536  llvm::raw_svector_ostream nameOstream(nameBuffer);
537  nameOstream << kInsertFuncNamePrefix;
538  const Level lvlRank = stt.getLvlRank();
539  for (Level l = 0; l < lvlRank; l++) {
540  std::string lvlType = toMLIRString(stt.getLvlType(l));
541  // Replace/remove punctuations in level properties.
542  std::replace_if(
543  lvlType.begin(), lvlType.end(),
544  [](char c) { return c == '(' || c == ','; }, '_');
545  llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
546  nameOstream << lvlType << "_";
547  }
548  // Static dim sizes are used in the generated code while dynamic sizes are
549  // loaded from the dimSizes buffer. This is the reason for adding the shape
550  // to the function name.
551  for (const auto sz : stt.getDimShape())
552  nameOstream << sz << "_";
553  // Permutation information is also used in generating insertion.
554  if (!stt.isIdentity())
555  nameOstream << stt.getDimToLvl() << "_";
556  nameOstream << stt.getElementType() << "_";
557  nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
558  return nameOstream.str().str();
559  }
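  // For example (assuming a 100x100 CSR tensor of f64 with default
  // bitwidths), the mangled name would look roughly like
  //   _insert_dense_compressed_100_100_f64_0_0
  // though the exact spelling follows toMLIRString of the level types.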
560 
561 private:
562  TensorType rtp;
563 };
564 
565 /// Sparse tensor storage conversion rule for returns.
566 class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
567 public:
568  using OpConversionPattern::OpConversionPattern;
569  LogicalResult
570  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
571  ConversionPatternRewriter &rewriter) const override {
572  SmallVector<Value> flattened;
573  flattenOperands(adaptor.getOperands(), flattened);
574  // Create a return with the flattened value extracted from sparse tensors.
575  rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
576  return success();
577  }
578 };
579 
580 /// Sparse tensor storage conversion rule for calls.
581 class SparseCallConverter : public OpConversionPattern<func::CallOp> {
582 public:
583  // The default CallOp converter cannot handle 1:N type conversion.
584  using OpConversionPattern::OpConversionPattern;
585  LogicalResult
586  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
587  ConversionPatternRewriter &rewriter) const override {
588  Location loc = op.getLoc();
589  // In case of:
590  // sparse_tensor, f, sparse_tensor = call @foo(...)
591  // ==>
592  // memref..., f, memref = call @foo(...) replace with
593  // cast(memref...)->sparse_tensor, f, cast(memref...)->sparse_tensor
594  SmallVector<Type> finalRetTy;
595  if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
596  return failure();
597 
598  // (1) Generates new call with flattened return value.
599  SmallVector<Value> flattened;
600  flattenOperands(adaptor.getOperands(), flattened);
601  auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
602  finalRetTy, flattened);
603  // (2) Gather sparse tensor returns.
604  SmallVector<SmallVector<Value>> packedResultVals;
605  // Tracks the offset of the current return value (of the original call)
606  // relative to the new call (after sparse tensor flattening).
607  unsigned retOffset = 0;
608  // Temporary buffer to hold the flattened list of types for
609  // a sparse tensor.
610  SmallVector<Type> sparseFlat;
611  for (auto ret : op.getResults()) {
612  assert(retOffset < newCall.getNumResults());
613  auto retType = ret.getType();
614  if (failed(typeConverter->convertType(retType, sparseFlat)))
615  llvm_unreachable("Failed to convert type in sparse tensor codegen");
616 
617  // Converted types cannot be empty when the type conversion succeeds.
618  assert(!sparseFlat.empty());
619  if (sparseFlat.size() > 1) {
620  auto flatSize = sparseFlat.size();
621  packedResultVals.emplace_back();
622  llvm::append_range(packedResultVals.back(),
623  newCall.getResults().slice(retOffset, flatSize));
624  retOffset += flatSize;
625  } else {
626  // If this is a 1:1 conversion, there is no need for casting.
627  packedResultVals.emplace_back();
628  packedResultVals.back().push_back(newCall.getResult(retOffset));
629  retOffset++;
630  }
631  sparseFlat.clear();
632  }
633 
634  assert(packedResultVals.size() == op.getNumResults());
635  rewriter.replaceOpWithMultiple(
636  op, llvm::to_vector_of<ValueRange>(packedResultVals));
637  return success();
638  }
639 };
640 
641 /// Sparse codegen rule for level accesses.
642 class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
643 public:
644  using OpConversionPattern::OpConversionPattern;
645  LogicalResult
646  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
647  ConversionPatternRewriter &rewriter) const override {
648  std::optional<int64_t> lvl = op.getConstantLvlIndex();
649  RankedTensorType srcType = op.getSource().getType();
650  if (!lvl || !getSparseTensorEncoding(srcType))
651  return failure();
652 
653  auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
654  auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
655 
656  rewriter.replaceOp(op, sz);
657  return success();
658  }
659 };
660 
661 // TODO: use a new SortCOO operation here instead of reusing convert op.
662 struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
663  using OpConversionPattern::OpConversionPattern;
664  LogicalResult
665  matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor,
666  ConversionPatternRewriter &rewriter) const override {
667  Location loc = op.getLoc();
668  MLIRContext *ctx = op.getContext();
669 
670  SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
671  SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());
672 
673  // Should have been verified.
674  assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
675  dstStt.isCOOType() && srcStt.isCOOType());
676  assert(dstStt.hasSameDimToLvl(srcStt));
677 
678  // We don't need a mutable descriptor here as we perform sorting in-place.
679  auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
680  op.getInputCoo().getType());
681  auto nnz = desc.getValMemSize(rewriter, op.getLoc());
682  auto crd = desc.getAOSMemRef();
683  auto val = desc.getValMemRef();
684 
685  // Otherwise we need another data shuffle and a non-identity map.
686  assert(dstStt.hasSameDimToLvl(srcStt));
687  (void)dstStt; // to silence warning when assertion is disabled
688 
689  auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
690 
691  rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
692  rewriter.getIndexAttr(0), op.getAlgorithm());
693 
694  // Since we do in-place sorting, the destination tensor will have the same set
695  // of memrefs as the source tensor.
696  rewriter.replaceOp(op, adaptor.getInputCoo());
697  return success();
698  }
699 };
700 
701 template <typename Op, StorageSpecifierKind kind>
702 class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
703 public:
704  using OpConversionPattern<Op>::OpConversionPattern;
705  LogicalResult
706  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
707  ConversionPatternRewriter &rewriter) const override {
708  // Simply lowers to a specifier.get <field> operation.
709  auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
710  op.getSlice().getType());
711  auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
712  op.getDim().getZExtValue());
713 
714  rewriter.replaceOp(op, v);
715  return success();
716  }
717 };
718 
719 /// Sparse codegen rule for trivial tensor casts.
720 class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
721 public:
722  using OpConversionPattern::OpConversionPattern;
723  LogicalResult
724  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
725  ConversionPatternRewriter &rewriter) const override {
726  // Only rewrite identically annotated source/dest.
727  auto encDst = getSparseTensorEncoding(op.getType());
728  auto encSrc = getSparseTensorEncoding(op.getSource().getType());
729  if (!encDst || encDst != encSrc)
730  return failure();
731  rewriter.replaceOp(op, adaptor.getOperands());
732  return success();
733  }
734 };
735 
736 class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
737 public:
738  using OpConversionPattern::OpConversionPattern;
739  LogicalResult
740  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
741  ConversionPatternRewriter &rewriter) const override {
742  // Simply fold the operation.
743  rewriter.replaceOp(op, adaptor.getSource());
744  return success();
745  }
746 };
747 
748 /// Sparse codegen rule for the alloc operator.
749 class SparseTensorAllocConverter
750  : public OpConversionPattern<bufferization::AllocTensorOp> {
751 public:
752  using OpConversionPattern::OpConversionPattern;
753  SparseTensorAllocConverter(const TypeConverter &typeConverter,
754  MLIRContext *context, bool enableInit)
755  : OpConversionPattern(typeConverter, context),
756  enableBufferInitialization(enableInit) {}
757 
758  LogicalResult
759  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
760  ConversionPatternRewriter &rewriter) const override {
761  const auto resType = getSparseTensorType(op);
762  if (!resType.hasEncoding())
763  return failure();
764 
765  Location loc = op.getLoc();
766  // Deal with copy.
767  if (op.getCopy()) {
768  auto desc = getDescriptorFromTensorTuple(
769  adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
770  SmallVector<Value> fields;
771  fields.reserve(desc.getNumFields());
772  // Memcpy on memref fields.
773  for (auto field : desc.getMemRefFields()) {
774  auto memrefTp = cast<MemRefType>(field.getType());
775  auto size = rewriter.create<memref::DimOp>(loc, field, 0);
776  auto copied =
777  rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size});
778  rewriter.create<memref::CopyOp>(loc, field, copied);
779  fields.push_back(copied);
780  }
781  // Reuses specifier.
782  fields.push_back(desc.getSpecifier());
783  assert(fields.size() == desc.getNumFields());
784  rewriter.replaceOpWithMultiple(op, {fields});
785  return success();
786  }
787 
788  if (!resType.isIdentity()) {
789  return rewriter.notifyMatchFailure(
790  op, "try run --sparse-reinterpret-map before codegen");
791  }
792  // Level sizes equal dimension sizes since the lvl2dim map is an identity map.
793  SmallVector<Value> lvlSizesValues;
794  createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
795  /*dimSizesValues=*/lvlSizesValues);
796 
797  // Construct allocation for each field.
798  Value sizeHint = op.getSizeHint();
799  SmallVector<Value> fields;
800  createAllocFields(rewriter, loc, resType, enableBufferInitialization,
801  sizeHint, lvlSizesValues, fields);
802 
803  // Replace operation with resulting memrefs.
804  rewriter.replaceOpWithMultiple(op, {fields});
805  return success();
806  }
807 
808 private:
809  bool enableBufferInitialization;
810 };
811 
812 /// Sparse codegen rule for the empty tensor operator.
813 class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
814 public:
815  using OpConversionPattern::OpConversionPattern;
816  SparseTensorEmptyConverter(const TypeConverter &typeConverter,
817  MLIRContext *context, bool enableInit)
818  : OpConversionPattern(typeConverter, context),
819  enableBufferInitialization(enableInit) {}
820 
821  LogicalResult
822  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
823  ConversionPatternRewriter &rewriter) const override {
824  const auto resType = getSparseTensorType(op);
825  if (!resType.hasEncoding())
826  return failure();
827 
828  if (!resType.isIdentity()) {
829  return rewriter.notifyMatchFailure(
830  op, "try run --sparse-reinterpret-map before codegen");
831  }
832 
833  Location loc = op.getLoc();
834  // Level sizes equal dimension sizes since the lvl2dim map is an identity map.
835  SmallVector<Value> lvlSizesValues;
836  createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
837  /*dimSizesValues=*/lvlSizesValues);
838  // Construct allocation for each field.
839  Value sizeHint; // none
840  SmallVector<Value> fields;
841  createAllocFields(rewriter, loc, resType, enableBufferInitialization,
842  sizeHint, lvlSizesValues, fields);
843 
844  // Replace operation with resulting memrefs.
845  rewriter.replaceOpWithMultiple(op, {fields});
846  return success();
847  }
848 
849 private:
850  bool enableBufferInitialization;
851 };
852 
853 /// Sparse codegen rule for the dealloc operator.
854 class SparseTensorDeallocConverter
855  : public OpConversionPattern<bufferization::DeallocTensorOp> {
856 public:
857  using OpConversionPattern::OpConversionPattern;
858  SparseTensorDeallocConverter(const TypeConverter &typeConverter,
859  MLIRContext *context, bool createDeallocs)
860  : OpConversionPattern(typeConverter, context),
861  createDeallocs(createDeallocs) {}
862 
863  LogicalResult
864  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
865  ConversionPatternRewriter &rewriter) const override {
866  auto enc = getSparseTensorEncoding(op.getTensor().getType());
867  if (!enc)
868  return failure();
869 
870  // If the user requests not to deallocate sparse tensors, simply erase the
871  // operation.
872  if (createDeallocs) {
873  // Replace the sparse tensor deallocation with field deallocations.
874  Location loc = op.getLoc();
875  auto desc = getDescriptorFromTensorTuple(
876  adaptor.getTensor(),
877  cast<RankedTensorType>(op.getTensor().getType()));
878  for (auto input : desc.getMemRefFields())
879  // Deallocate every buffer used to store the sparse tensor handler.
880  rewriter.create<memref::DeallocOp>(loc, input);
881  }
882  rewriter.eraseOp(op);
883  return success();
884  }
885 
886 private:
887  const bool createDeallocs;
888 };
889 
890 /// Sparse codegen rule for tensor rematerialization.
891 class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
892 public:
893  using OpConversionPattern::OpConversionPattern;
894  LogicalResult
895  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
896  ConversionPatternRewriter &rewriter) const override {
897  // Prepare descriptor.
898  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
899  op.getTensor().getType());
900  // Generate optional insertion finalization code.
901  if (op.getHasInserts())
902  genEndInsert(rewriter, op.getLoc(), desc);
903  // Replace operation with resulting memrefs.
904  rewriter.replaceOpWithMultiple(op, {desc.getFields()});
905  return success();
906  }
907 };
908 
909 /// Sparse codegen rule for the expand op.
910 class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
911 public:
912  using OpConversionPattern::OpConversionPattern;
913  LogicalResult
914  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
915  ConversionPatternRewriter &rewriter) const override {
916  if (!getSparseTensorEncoding(op.getTensor().getType()))
917  return failure();
918  Location loc = op->getLoc();
919  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
920  op.getTensor().getType());
921  const auto srcType = getSparseTensorType(op.getTensor());
922  Type eltType = srcType.getElementType();
923  Type boolType = rewriter.getIntegerType(1);
924  Type idxType = rewriter.getIndexType();
925  // All initialization should be done on entry of the loop nest.
926  rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
927 
928  // Determine the size for access expansion (always the innermost stored
929  // level size).
930  const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
931  // Generate a memref for `sz` elements of type `t`.
932  const auto genAlloc = [&](Type t) {
933  const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
934  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
935  };
936  // Allocate temporary buffers for values/filled-switch and added.
937  // We do not use stack buffers for this, since the expanded size may
938  // be rather large (as it envelops a single expanded dense dimension).
939  Value values = genAlloc(eltType);
940  Value filled = genAlloc(boolType);
941  Value added = genAlloc(idxType);
942  Value zero = constantZero(rewriter, loc, idxType);
943  // Reset the values/filled-switch to all-zero/false. Note that this
944  // introduces an O(N) operation into the computation, but this reset
945  // operation is amortized over the innermost loops for the access
946  // pattern expansion. As noted in the operation doc, we would like
947  // to amortize this setup cost even between kernels.
948  rewriter.create<linalg::FillOp>(
949  loc, ValueRange{constantZero(rewriter, loc, eltType)},
950  ValueRange{values});
951  rewriter.create<linalg::FillOp>(
952  loc, ValueRange{constantZero(rewriter, loc, boolType)},
953  ValueRange{filled});
954  // Replace expansion op with these buffers and initial coordinate.
955  assert(op.getNumResults() == 4);
956  rewriter.replaceOp(op, {values, filled, added, zero});
957  return success();
958  }
959 };
960 
961 /// Sparse codegen rule for the compress operator.
962 class SparseCompressConverter : public OpConversionPattern<CompressOp> {
963 public:
964  using OpConversionPattern::OpConversionPattern;
965  LogicalResult
966  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
967  ConversionPatternRewriter &rewriter) const override {
968  Location loc = op->getLoc();
969  SmallVector<Value> fields;
970  auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
971  op.getTensor().getType());
972  Value values = adaptor.getValues();
973  Value filled = adaptor.getFilled();
974  Value added = adaptor.getAdded();
975  Value count = adaptor.getCount();
976  const SparseTensorType dstType(desc.getRankedTensorType());
977  Type eltType = dstType.getElementType();
978 
979  // If the innermost level is ordered, we need to sort the coordinates
980  // in the "added" array prior to applying the compression.
981  if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
982  rewriter.create<SortOp>(
983  loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
984  rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
985  // While performing the insertions, we also need to reset the elements
986  // of the values/filled-switch by only iterating over the set elements,
987  // to ensure that the runtime complexity remains proportional to the
988  // sparsity of the expanded access pattern.
989  //
990  // Generate
991  // out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
992  // crd = added[i];
993  // value = values[crd];
994  // insert({lvlCoords, crd}, value);
995  // new_memrefs = insert(in_memrefs, {lvlCoords, crd}, value);
996  // values[crd] = 0;
997  // filled[crd] = false;
998  // yield new_memrefs
999  // }
1000  scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
1001  Value i = loop.getInductionVar();
1002 
1003  Value crd = genLoad(rewriter, loc, added, i);
1004  Value value = genLoad(rewriter, loc, values, crd);
1005  SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
1006  SmallVector<Type> flatSpTensorTps = llvm::to_vector(
1007  llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
1008  params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
1009  params.push_back(crd);
1010  params.push_back(value);
1011  SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
1012  params, /*genCall=*/true);
1013  SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
1014  genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
1015  genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
1016  rewriter.create<scf::YieldOp>(loc, insertRet);
1017 
1018  rewriter.setInsertionPointAfter(loop);
1019  // Deallocate the buffers on exit of the full loop nest.
1020  Operation *parent = getTop(op);
1021  rewriter.setInsertionPointAfter(parent);
1022  rewriter.create<memref::DeallocOp>(loc, values);
1023  rewriter.create<memref::DeallocOp>(loc, filled);
1024  rewriter.create<memref::DeallocOp>(loc, added);
1025  // Replace operation with resulting memrefs.
1026  rewriter.replaceOpWithMultiple(op, {loop->getResults()});
1027  return success();
1028  }
1029 };
1030 
1031 /// Sparse codegen rule for the insert operator.
1032 class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
1033 public:
1034  using OpConversionPattern::OpConversionPattern;
1035  LogicalResult
1036  matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
1037  ConversionPatternRewriter &rewriter) const override {
1038  auto stt = getSparseTensorType(adaptor.getDest());
1039  if (!stt.hasEncoding())
1040  return failure();
1041  assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
1042 
1043  Location loc = op.getLoc();
1044  auto desc =
1045  getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
1046  TypeRange flatSpTensorTps = desc.getFields().getTypes();
1047  SmallVector<Value> params = llvm::to_vector(desc.getFields());
1048  params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
1049  params.push_back(adaptor.getScalar());
1050  SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
1051  params, /*genCall=*/true);
1052  SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
1053  // Replace operation with resulting memrefs.
1054  rewriter.replaceOpWithMultiple(op, {ret});
1055  return success();
1056  }
1057 };
1058 
1059 /// Sparse codegen rule for position accesses.
1060 class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
1061 public:
1062  using OpAdaptor = typename ToPositionsOp::Adaptor;
1063  using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
1064  LogicalResult
1065  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
1066  ConversionPatternRewriter &rewriter) const override {
1067  // Replace the requested position access with corresponding field.
1068  // The view is restricted to the actual size to ensure clients
1069  // of this operation truly observe size, not capacity!
1070  Location loc = op.getLoc();
1071  Level lvl = op.getLevel();
1072  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1073  op.getTensor().getType());
1074  auto mem = desc.getPosMemRef(lvl);
1075  auto size = desc.getPosMemSize(rewriter, loc, lvl);
1076  rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1077  return success();
1078  }
1079 };
1080 
1081 /// Sparse codegen rule for accessing the coordinates arrays.
1082 class SparseToCoordinatesConverter
1083  : public OpConversionPattern<ToCoordinatesOp> {
1084 public:
1085  using OpAdaptor = typename ToCoordinatesOp::Adaptor;
1086  using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
1087  LogicalResult
1088  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
1089  ConversionPatternRewriter &rewriter) const override {
1090  // Replace the requested coordinates access with corresponding field.
1091  // The view is restricted to the actual size to ensure clients
1092  // of this operation truly observe size, not capacity!
1093  Location loc = op.getLoc();
1094  Level lvl = op.getLevel();
1095  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1096  op.getTensor().getType());
1097  auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
1098  if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
1099  auto size = desc.getCrdMemSize(rewriter, loc, lvl);
1100  mem = genSliceToSize(rewriter, loc, mem, size);
1101  }
1102  rewriter.replaceOp(op, mem);
1103  return success();
1104  }
1105 };
1106 
1107 /// Sparse codegen rule for accessing the linear coordinates buffer.
1108 class SparseToCoordinatesBufferConverter
1109  : public OpConversionPattern<ToCoordinatesBufferOp> {
1110 public:
1111  using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
1112  using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
1113  LogicalResult
1114  matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
1115  ConversionPatternRewriter &rewriter) const override {
1116  // Replace the requested coordinates access with corresponding field.
1117  // The view is restricted to the actual size to ensure clients
1118  // of this operation truly observe size, not capacity!
1119  Location loc = op.getLoc();
1120  Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
1121  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1122  op.getTensor().getType());
1123  auto mem = desc.getAOSMemRef();
1124  auto size = desc.getCrdMemSize(rewriter, loc, lvl);
1125  rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1126  return success();
1127  }
1128 };
1129 
1130 /// Sparse codegen rule for value accesses.
1131 class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
1132 public:
1133  using OpAdaptor = typename ToValuesOp::Adaptor;
1134  using OpConversionPattern<ToValuesOp>::OpConversionPattern;
1135  LogicalResult
1136  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
1137  ConversionPatternRewriter &rewriter) const override {
1138  // Replace the requested values access with corresponding field.
1139  // The view is restricted to the actual size to ensure clients
1140  // of this operation truly observe size, not capacity!
1141  Location loc = op.getLoc();
1142  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1143  op.getTensor().getType());
1144  auto mem = desc.getValMemRef();
1145  auto size = desc.getValMemSize(rewriter, loc);
1146  rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1147  return success();
1148  }
1149 };
1150 
1151 /// Sparse codegen rule for the convert operator.
1152 class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
1153 public:
1154  using OpConversionPattern::OpConversionPattern;
1155  LogicalResult
1156  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
1157  ConversionPatternRewriter &rewriter) const override {
1158  SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
1159  SparseTensorEncodingAttr encSrc =
1160  getSparseTensorEncoding(op.getSource().getType());
1161  // The output tensor cannot be a slice; such cases should have been
1162  // rejected by ConvertOp::verify() already.
1163  assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slice.");
1164  // Different encoding (except for different bitwidth) should be handled by
1165  // rewriting.
1166  // We need further rewrites if the input tensor is a slice too.
1167  if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
1168  encSrc.isSlice()) {
1169  return failure();
1170  }
1171 
1172  Type retElemTp = op.getResult().getType().getElementType();
1173  Type srcElemTp = op.getSource().getType().getElementType();
1174  // Fold the trivial cases.
1175  if (retElemTp == srcElemTp && encDst == encSrc) {
1176  rewriter.replaceOp(op, adaptor.getSource());
1177  return success();
1178  }
1179  //
1180  // Do element-wise type conversion without using InsertOp.
1181  //
1182  // for each memref in srcTensor:
1183  // dst = memref.alloc
1184  // if srcMemRefType != dstMemRefType:
1185  // for every dst[i] = cast(src[i])
1186  // else:
1187  // dst = memref.copy(src)
1188  Location loc = op.getLoc();
1189  auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
1190  op.getSource().getType());
1191  SmallVector<Value> fields;
1192  foreachFieldAndTypeInSparseTensor(
1193  SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
1194  [&rewriter, &fields, srcDesc,
1195  loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
1196  LevelType /*lt*/) -> bool {
1197  // Simply reuses the storage specifier as it is an SSA value.
1198  if (fKind == SparseTensorFieldKind::StorageSpec) {
1199  fields.push_back(srcDesc.getSpecifier());
1200  } else {
1201  // Allocates new memrefs
1202  Value srcMem = srcDesc.getMemRefField(fIdx);
1203  // TODO: We can instead use the actual memSize in specifier, that
1204  // would require a subViewOp to avoid overflow when copying
1205  // values.
1206  Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
1207  auto dstMem = rewriter.create<memref::AllocOp>(
1208  loc, cast<MemRefType>(fTp), sz);
1209  if (fTp != srcMem.getType()) {
1210  // Converts elements type.
1211  scf::buildLoopNest(
1212  rewriter, loc, constantIndex(rewriter, loc, 0), sz,
1213  constantIndex(rewriter, loc, 1),
1214  [srcMem, &dstMem](OpBuilder &builder, Location loc,
1215  ValueRange ivs) {
1216  Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
1217  Value casted = genCast(builder, loc, v,
1218  dstMem.getType().getElementType());
1219  builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
1220  });
1221  } else {
1222  // TODO: We can even reuse the same memref for the new tensor,
1223  // but that requires a `ref-counting` based memory management
1224  // for shared memrefs between multiple sparse tensors.
1225  rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
1226  }
1227  fields.push_back(dstMem);
1228  }
1229  return true;
1230  });
1231 
1232  rewriter.replaceOpWithMultiple(op, {fields});
1233  return success();
1234  }
1235 };
1236 
1237 class SparseExtractSliceConverter
1238  : public OpConversionPattern<tensor::ExtractSliceOp> {
1239 public:
1240  using OpConversionPattern::OpConversionPattern;
1241  LogicalResult
1242  matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
1243  ConversionPatternRewriter &rewriter) const override {
1244  Location loc = op.getLoc();
1245  MLIRContext *ctx = op.getContext();
1246  auto srcEnc = getSparseTensorEncoding(op.getSourceType());
1247  auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
1248  // TODO: We should check these in ExtractSliceOp::verify.
1249  if (!srcEnc || !dstEnc || !dstEnc.isSlice())
1250  return failure();
1251  assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
1252 
1253  SmallVector<Value> fields;
1254  auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
1255  op.getSource().getType());
1256 
1257  auto newSpec = rewriter.create<StorageSpecifierInitOp>(
1258  loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
1259  desc.setSpecifier(newSpec);
1260 
1261  // Fills in slice information.
1262  for (auto [idx, offset, size, stride] : llvm::enumerate(
1263  op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
1264  Dimension dim = idx;
1265 
1266  Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
1267  Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);
1268  Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride);
1269  // TODO: We could probably only set the dynamic value here. But it would
1270  // require us to fill the hole when casting a static slice to a dynamic
1271  // slice.
1272  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset,
1273  dim, offsetV);
1274 
1275  // FIXME: we need to distinguish level sizes and dimension size for slices
1276  // here. Maybe we should store slice level sizes in a different array
1277  // instead of reusing it.
1278  assert(srcEnc.isIdentity());
1279  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
1280  sizeV);
1281  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
1282  dim, strideV);
1283  }
1284 
1285  // NOTE: we cannot generate tuples directly from the descriptor here, as the
1286  // descriptor is holding the original type, yet we want the slice type
1287  // here (they share every memref, but with an updated specifier).
1288  rewriter.replaceOpWithMultiple(op, {desc.getFields()});
1289  return success();
1290  }
1291 };
1292 
1293 /// Sparse codegen rule for number of entries operator.
1294 class SparseNumberOfEntriesConverter
1295  : public OpConversionPattern<NumberOfEntriesOp> {
1296 public:
1297  using OpConversionPattern::OpConversionPattern;
1298  LogicalResult
1299  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
1300  ConversionPatternRewriter &rewriter) const override {
1301  // Query memSizes for the actually stored values.
1302  // FIXME: the nse value computed in this way might be wrong when there is
1303  // any "loose_compressed" level.
1304  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1305  op.getTensor().getType());
1306  rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
1307  return success();
1308  }
1309 };
1310 
1311 struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
1312  using OpConversionPattern::OpConversionPattern;
1313  LogicalResult
1314  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
1315  ConversionPatternRewriter &rewriter) const override {
1316  Location loc = op.getLoc();
1317  const auto stt = getSparseTensorType(op.getResult());
1318 
1319  SmallVector<Value> fields;
1320 
1321  foreachFieldAndTypeInSparseTensor(
1322  stt,
1323  [&rewriter, &fields, &op, &stt,
1324  loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
1325  Level /*lvl*/, LevelType lt) -> bool {
1326  assert(fields.size() == fIdx);
1327  if (fKind == SparseTensorFieldKind::StorageSpec) {
1328  fields.push_back(
1329  SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
1330  } else {
1331  // Else simply takes the inputs.
1332  Value tensor = fKind == SparseTensorFieldKind::ValMemRef
1333  ? op.getValues()
1334  : op.getLevels()[fIdx];
1335  // TODO: handle batch.
1336  TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
1337  if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
1338  // Flattens the buffer to batchLvlRank.
1339  auto reassoc = getReassociationForFlattening(
1340  mem.getType(), stt.getBatchLvlRank());
1341  mem = rewriter.create<memref::CastOp>(
1342  loc, fType,
1343  rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
1344  } else {
1345  mem = rewriter.create<memref::CastOp>(loc, fType, mem);
1346  }
1347  fields.push_back(mem);
1348  }
1349  return true;
1350  });
1351 
1352  MutSparseTensorDescriptor desc(stt, fields);
1353  Value c0 = constantIndex(rewriter, loc, 0);
1354  Value c1 = constantIndex(rewriter, loc, 1);
1355  Value c2 = constantIndex(rewriter, loc, 2);
1356  Value posBack = c0; // index to the last value in the position array
1357  Value memSize = c1; // memory size for current array
1358 
1359  Level trailCOOStart = stt.getAoSCOOStart();
1360  Level trailCOORank = stt.getLvlRank() - trailCOOStart;
1361  // Sets up SparseTensorSpecifier.
1362  for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
1363  assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
1364 
1365  // Sets up the level size.
1366  auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
1367  desc.setLvlSize(rewriter, loc, lvl, lvlSize);
1368  // We use a single AOS array to store the trailing COO, so there is only
1369  // one memory size to set for the entire COO section.
1370  if (lvl > trailCOOStart)
1371  continue;
1372 
1373  // Sets up the memory size by reading the last value in position array.
1374  LevelType lt = stt.getLvlType(lvl);
1375  // Simply forwards the position index when this is a dense level.
1376  if (lt.isa<LevelFormat::Dense>()) {
1377  memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
1378  posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
1379  continue;
1380  }
1381  if (lt.isa<LevelFormat::Batch>()) {
1382  // Skips batch levels as they are not linearized.
1383  // FIXME: this assumes that every batch has the same number of nse, need
1384  // to be generalized to handle varied-size batches.
1385  continue;
1386  }
1387 
1388  if (isWithPosLT(lt)) {
1389  assert(isCompressedLT(lt) || isLooseCompressedLT(lt));
1390  if (isLooseCompressedLT(lt)) {
1391  memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
1392  posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
1393  } else {
1394  assert(isCompressedLT(lt));
1395  posBack = memSize;
1396  memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
1397  }
1398  desc.setPosMemSize(rewriter, loc, lvl, memSize);
1399  // The last value in the position array is the memory size for the next level.
1400  // FIXME: this assumes that every batch has the same nse; this needs to
1401  // be generalized to handle varied-size batches.
1402  SmallVector<Value> batched(stt.getBatchLvlRank(),
1403  constantIndex(rewriter, loc, 0));
1404  batched.push_back(posBack);
1405  memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
1406  posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
1407  }
1408  assert(isWithCrdLT(lt) && lvl <= trailCOOStart);
1409  // FIXME: This seems unnecessarily complex; can we simplify it?
1410  if (lvl == trailCOOStart) {
1411  Value cooSz = rewriter.create<arith::MulIOp>(
1412  loc, memSize, constantIndex(rewriter, loc, trailCOORank));
1413  desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
1414  } else {
1415  desc.setCrdMemSize(rewriter, loc, lvl, memSize);
1416  }
1417  }
1418  desc.setValMemSize(rewriter, loc, memSize);
1419 
1420  rewriter.replaceOpWithMultiple(op, {desc.getFields()});
1421  return success();
1422  }
1423 };
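
 To make the memSize/posBack recurrence in the loop above concrete, here is an illustrative compile-time trace for a plain 2-D CSR tensor (no batch levels, no trailing COO). The struct and function are a hedged sketch, not part of this file, and assume the last entry of the user-supplied positions array equals the number of stored entries.

 struct CsrMemSizes {
   int64_t pos; // positions buffer length
   int64_t crd; // coordinates buffer length
   int64_t val; // values buffer length
 };
 static CsrMemSizes csrMemSizes(int64_t rows, int64_t nse) {
   int64_t memSize = 1;
   memSize *= rows;              // dense level 0: memSize = lvlSize * memSize
   int64_t posLen = memSize + 1; // compressed level 1: one extra position entry
   // The last position entry (read with genIndexLoad in the pattern) yields the
   // number of stored entries, which sizes the coordinates and values buffers.
   return {posLen, nse, nse};
 }

 For a 100x100 CSR tensor with 42 stored entries this yields pos = 101, crd = 42, val = 42, matching the specifier values the pattern writes.
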
1424 
1425 struct SparseDisassembleOpConverter
1426  : public OpConversionPattern<DisassembleOp> {
1427  using OpConversionPattern::OpConversionPattern;
1428  SparseDisassembleOpConverter(const TypeConverter &typeConverter,
1429  MLIRContext *context)
1430  : OpConversionPattern(typeConverter, context) {}
1431 
1432  LogicalResult
1433  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
1434  ConversionPatternRewriter &rewriter) const override {
1435  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1436  op.getTensor().getType());
1437  Location loc = op.getLoc();
1438  SmallVector<Value> retMem;
1439  SmallVector<Value> retLen;
1440  desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem,
1441  &retLen](FieldIndex fid,
1442  SparseTensorFieldKind fKind,
1443  Level lvl, LevelType lt) -> bool {
1444  if (fKind == SparseTensorFieldKind::StorageSpec)
1445  return true;
1446  SparseTensorType stt(desc.getRankedTensorType());
1447  Value sz, src;
1448  TypedValue<BaseMemRefType> dst;
1449  if (fKind == SparseTensorFieldKind::ValMemRef) {
1450  sz = desc.getValMemSize(rewriter, loc);
1451  src = desc.getValMemRef();
1452  dst = genToMemref(rewriter, loc, op.getOutValues());
1453 
1454  retMem.push_back(dst);
1455  Type valLenTp = op.getValLen().getType();
1456  retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
1457  } else {
1458  assert(fKind == SparseTensorFieldKind::PosMemRef ||
1459  fKind == SparseTensorFieldKind::CrdMemRef);
1460 
1461  sz = fKind == SparseTensorFieldKind::PosMemRef
1462  ? desc.getPosMemSize(rewriter, loc, lvl)
1463  : desc.getCrdMemSize(rewriter, loc, lvl);
1464  src = desc.getMemRefField(fid);
1465  dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
1466  retMem.push_back(dst);
1467  // Retrieves the corresponding level length type.
1468  Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
1469  retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
1470  }
1471  Value flatOut = dst;
1472  if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
1473  auto reassoc =
1474  getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());
1475  flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
1476  }
1477  Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
1478  Value srcMem = genSliceToSize(rewriter, loc, src, sz);
1479  rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
1480  return true;
1481  });
1482 
1483  // Converts MemRefs back to Tensors.
1484  SmallVector<Value> retValues = llvm::to_vector(
1485  llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
1486  return rewriter.create<bufferization::ToTensorOp>(loc, v);
1487  }));
1488  // Appends the actual memory length used in each buffer returned.
1489  retValues.append(retLen.begin(), retLen.end());
1490  rewriter.replaceOp(op, retValues);
1491  return success();
1492  }
1493 };
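
 The body of the lambda above always narrows both the internal buffer and the user-provided output buffer to the in-use length before copying. A condensed sketch of that step follows; the wrapper function is hypothetical, while genSliceToSize and memref::CopyOp are the same helpers used above.

 static void copyFieldOut(OpBuilder &builder, Location loc, Value src,
                          Value dst, Value sz) {
   // Slice source and destination to the actually used length `sz`, so the
   // copy never touches trailing (unused) buffer capacity.
   Value srcMem = genSliceToSize(builder, loc, src, sz);
   Value dstMem = genSliceToSize(builder, loc, dst, sz);
   builder.create<memref::CopyOp>(loc, srcMem, dstMem);
 }
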
1494 
1495 struct SparseNewConverter : public OpConversionPattern<NewOp> {
1496  using OpConversionPattern::OpConversionPattern;
1497  LogicalResult
1498  matchAndRewrite(NewOp op, OpAdaptor adaptor,
1499  ConversionPatternRewriter &rewriter) const override {
1500  Location loc = op.getLoc();
1501  const auto dstTp = getSparseTensorType(op.getResult());
1502  // Creating COO with NewOp is handled by direct IR codegen. All other cases
1503  // are handled by rewriting.
1504  if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
1505  return failure();
1506 
1507  // Implement as follows:
1508  // %reader = @createCheckedSparseTensorReader(%filename)
1509  // %nse = @getSparseTensorNSE(%reader)
1510  // %coo = bufferization.alloc_tensor an ordered COO with
1511  // dst dim ordering, size_hint = %nse
1512  // %coordinates = sparse_tensor.coordinates_buffer(%coo)
1513  // %values = sparse_tensor.values(%coo)
1514  // %isSorted = @sparseTensorReaderReadToBuffers(%coordinates, %values)
1515  // if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
1516  // update storage specifier
1517  // @delSparseTensorReader(%reader)
1518  SmallVector<Value> dimSizesValues;
1519  Value dimSizesBuffer;
1520  Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
1521  dimSizesValues, dimSizesBuffer);
1522 
1523  // Get the number of stored entries.
1524  const Type indexTp = rewriter.getIndexType();
1525  Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
1526  {indexTp}, {reader}, EmitCInterface::Off)
1527  .getResult(0);
1528 
1529  // Construct the lvl sizes and the dim2lvl/lvl2dim buffers.
1530  SmallVector<Value> lvlSizesValues;
1531  Value dim2lvlBuffer;
1532  Value lvl2dimBuffer;
1533  genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
1534  lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
1535 
1536  // Construct allocation for each field.
1537  Value sizeHint = nse;
1538  SmallVector<Value> fields;
1539  createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint,
1540  lvlSizesValues, fields);
1541 
1542  // Read the COO tensor data.
1543  MutSparseTensorDescriptor desc(dstTp, fields);
1544  Value xs = desc.getAOSMemRef();
1545  Value ys = desc.getValMemRef();
1546  const Type boolTp = rewriter.getIntegerType(1);
1547  const Type elemTp = dstTp.getElementType();
1548  const Type crdTp = dstTp.getCrdType();
1549  SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
1550  overheadTypeFunctionSuffix(crdTp),
1551  primaryTypeFunctionSuffix(elemTp)};
1552  Value isSorted =
1553  createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
1554  {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
1555  EmitCInterface::On)
1556  .getResult(0);
1557 
1558  // If the destination tensor is a sorted COO, we need to sort the COO tensor
1559  // data if the input elements aren't sorted yet.
1560  const Level lvlRank = dstTp.getLvlRank();
1561  if (dstTp.isOrderedLvl(lvlRank - 1)) {
1562  Value kFalse = constantI1(rewriter, loc, false);
1563  Value notSorted = rewriter.create<arith::CmpIOp>(
1564  loc, arith::CmpIPredicate::eq, isSorted, kFalse);
1565  scf::IfOp ifOp =
1566  rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
1567  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1568  auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
1569  rewriter.create<SortOp>(loc, nse, xs, ValueRange{ys}, xPerm,
1570  rewriter.getIndexAttr(0),
1571  SparseTensorSortKind::HybridQuickSort);
1572  rewriter.setInsertionPointAfter(ifOp);
1573  }
1574 
1575  // Set PosMemRef0[1] = nse.
1576  const Value c1 = constantIndex(rewriter, loc, 1);
1577  const Value posMemref0 = desc.getPosMemRef(0);
1578  const Type posTp = dstTp.getPosType();
1579  const Value posNse = genCast(rewriter, loc, nse, posTp);
1580  rewriter.create<memref::StoreOp>(loc, posNse, posMemref0, c1);
1581 
1582  // Update storage specifier.
1583  Value coordinatesSize = rewriter.create<arith::MulIOp>(
1584  loc, nse, constantIndex(rewriter, loc, lvlRank));
1585  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0,
1586  coordinatesSize);
1587  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize,
1588  std::nullopt, nse);
1589 
1590  // Release the sparse tensor reader.
1591  createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
1592  EmitCInterface::Off);
1593 
1594  // Replace operation with resulting memrefs.
1595  rewriter.replaceOpWithMultiple(op, {fields});
1596  return success();
1597  }
1598 };
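
 All reader interaction above is funneled through createFuncCall. As a hedged example, the NSE query emitted by this pattern could be wrapped as follows; the wrapper name is illustrative, while the runtime entry point and calling convention are the ones used above.

 static Value queryReaderNSE(OpBuilder &builder, Location loc, Value reader) {
   // Query the number of stored entries from an opened sparse tensor reader.
   Type indexTp = builder.getIndexType();
   return createFuncCall(builder, loc, "getSparseTensorReaderNSE", {indexTp},
                         {reader}, EmitCInterface::Off)
       .getResult(0);
 }
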
1599 
1600 struct SparseHasRuntimeLibraryConverter
1601  : public OpConversionPattern<HasRuntimeLibraryOp> {
1602  using OpConversionPattern::OpConversionPattern;
1603  LogicalResult
1604  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
1605  ConversionPatternRewriter &rewriter) const override {
1606  auto i1Type = rewriter.getI1Type();
1607  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
1608  op, i1Type, rewriter.getIntegerAttr(i1Type, 0));
1609  return success();
1610  }
1611 };
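
 Because this codegen path never relies on the runtime support library (beyond the file reader above), the query folds to a constant false. An illustrative builder-form sketch of the replacement follows; the free function is not part of this file.

 static Value buildHasRuntimeLibrary(OpBuilder &builder, Location loc) {
   // Direct codegen answers sparse_tensor.has_runtime_library with `false`.
   Type i1 = builder.getI1Type();
   return builder.create<arith::ConstantOp>(loc, i1,
                                            builder.getIntegerAttr(i1, 0));
 }
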
1612 
1613 } // namespace
1614 
1615 //===----------------------------------------------------------------------===//
1616 // Public method for populating conversion rules.
1617 //===----------------------------------------------------------------------===//
1618 
1619 /// Populates the given patterns list with conversion rules required for
1620 /// the sparsification of linear algebra operations.
1621 void mlir::populateSparseTensorCodegenPatterns(
1622  const TypeConverter &typeConverter, RewritePatternSet &patterns,
1623  bool createSparseDeallocs, bool enableBufferInitialization) {
1624  patterns.add<
1625  SparseAssembleOpConverter, SparseDisassembleOpConverter,
1626  SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
1627  SparseCastConverter, SparseExtractSliceConverter,
1628  SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter,
1629  SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter,
1630  SparseSliceGetterOpConverter<ToSliceOffsetOp,
1631  StorageSpecifierKind::DimOffset>,
1632  SparseSliceGetterOpConverter<ToSliceStrideOp,
1633  StorageSpecifierKind::DimStride>,
1634  SparseToPositionsConverter, SparseToCoordinatesConverter,
1635  SparseToCoordinatesBufferConverter, SparseToValuesConverter,
1636  SparseConvertConverter, SparseNewConverter,
1637  SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
1638  typeConverter, patterns.getContext());
1639  patterns.add<SparseTensorDeallocConverter>(
1640  typeConverter, patterns.getContext(), createSparseDeallocs);
1641  patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
1642  typeConverter, patterns.getContext(), enableBufferInitialization);
1643 }
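
 As a hedged usage sketch (the driver function below is illustrative; it assumes the SparseTensorTypeToBufferConverter declared alongside the sparse tensor passes, and the conversion-target details of the real pass are simplified), a pass would typically populate a pattern set with these rules and run a partial conversion:

 static LogicalResult runSparseTensorCodegen(Operation *root,
                                             MLIRContext *ctx) {
   // Map sparse tensor types to their constituent buffer fields.
   SparseTensorTypeToBufferConverter converter;
   ConversionTarget target(*ctx);
   // A real pass also marks sparse tensor ops illegal and configures dynamic
   // legality; only the legal dialects are shown here.
   target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
                          scf::SCFDialect>();
   RewritePatternSet patterns(ctx);
   populateSparseTensorCodegenPatterns(converter, patterns,
                                       /*createSparseDeallocs=*/true,
                                       /*enableBufferInitialization=*/false);
   return applyPartialConversion(root, target, std::move(patterns));
 }
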