1 //===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // A pass that converts sparse tensor types and primitives to actual compiler
10 // visible buffers and actual compiler IR that implements these primitives on
11 // the selected sparse tensor storage schemes. This pass provides an alternative
12 // to the SparseTensorConversion pass, eliminating the dependence on a runtime
13 // support library (other than for file I/O), and providing many more
14 // opportunities for subsequent compiler optimization of the generated code.
15 //
16 //===----------------------------------------------------------------------===//
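// For illustration only (a sketch; the encoding name #CSR is assumed): a
// matrix of type tensor<?x?xf64, #CSR> with a (dense, compressed) level
// scheme is lowered by this pass to a flat list of compiler-visible fields,
// roughly:
//   memref<?xindex>  positions[1]    // row pointers
//   memref<?xindex>  coordinates[1]  // column indices
//   memref<?xf64>    values          // stored nonzero values
//   !sparse_tensor.storage_specifier // per-level sizes and memory sizes
// Every conversion pattern below rewrites one sparse tensor SSA value into
// such a list of fields (a 1:N type conversion).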
17 
18 #include "Utils/CodegenUtils.h"
19 #include "Utils/SparseTensorDescriptor.h"
20 
21 #include "mlir/Dialect/Arith/Utils/Utils.h"
22 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"
24 #include "mlir/Dialect/Linalg/Utils/Utils.h"
25 #include "mlir/Dialect/MemRef/IR/MemRef.h"
26 #include "mlir/Dialect/SCF/IR/SCF.h"
27 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
28 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
29 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
30 #include "mlir/Dialect/Tensor/IR/Tensor.h"
31 #include "mlir/Transforms/DialectConversion.h"
32 
33 #include <optional>
34 
35 using namespace mlir;
36 using namespace mlir::sparse_tensor;
37 
38 //===----------------------------------------------------------------------===//
39 // Helper methods.
40 //===----------------------------------------------------------------------===//
41 
42 /// Flatten the given value ranges into a single vector of values.
43 static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
44  SmallVector<Value> result;
45  for (const auto &vals : values)
46  llvm::append_range(result, vals);
47  return result;
48 }
49 
50 /// Assert that the given value range contains a single value and return it.
51 static Value getSingleValue(ValueRange values) {
52  assert(values.size() == 1 && "expected single value");
53  return values.front();
54 }
55 
56 /// Generates a load with proper `index` typing.
57 static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
58  idx = genCast(builder, loc, idx, builder.getIndexType());
59  return builder.create<memref::LoadOp>(loc, mem, idx);
60 }
61 
62 /// Generates a store with proper `index` typing and proper value.
63 static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
64  Value idx) {
65  idx = genCast(builder, loc, idx, builder.getIndexType());
66  val = genCast(builder, loc, val,
67  cast<ShapedType>(mem.getType()).getElementType());
68  builder.create<memref::StoreOp>(loc, val, mem, idx);
69 }
70 
71 /// Creates a straightforward counting for-loop.
72 static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
73  MutableArrayRef<Value> fields,
74  Value lower = Value()) {
75  Type indexType = builder.getIndexType();
76  if (!lower)
77  lower = constantZero(builder, loc, indexType);
78  Value one = constantOne(builder, loc, indexType);
79  scf::ForOp forOp = builder.create<scf::ForOp>(loc, lower, upper, one, fields);
80  for (unsigned i = 0, e = fields.size(); i < e; i++)
81  fields[i] = forOp.getRegionIterArg(i);
82  builder.setInsertionPointToStart(forOp.getBody());
83  return forOp;
84 }
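// A minimal sketch of the loop structure createFor produces (assuming a
// single iter_arg field of type memref<?xindex>):
//
//   scf.for %i = %lo to %hi step %c1 iter_args(%f = %field) -> (memref<?xindex>) {
//     ... body built by the caller ...
//     scf.yield %f2 : memref<?xindex>
//   }
//
// The insertion point is left at the start of the body, so the caller is
// responsible for populating it and terminating it with scf.yield.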
85 
86 /// Creates a push back operation.
87 static void createPushback(OpBuilder &builder, Location loc,
88  MutSparseTensorDescriptor desc,
89  SparseTensorFieldKind kind, std::optional<Level> lvl,
90  Value value, Value repeat = Value()) {
91  Type etp = desc.getMemRefElementType(kind, lvl);
92  Value field = desc.getMemRefField(kind, lvl);
93  StorageSpecifierKind specFieldKind = toSpecifierKind(kind);
94 
95  auto pushBackOp = builder.create<PushBackOp>(
96  loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field,
97  genCast(builder, loc, value, etp), repeat);
98 
99  desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
100  desc.setSpecifierField(builder, loc, specFieldKind, lvl,
101  pushBackOp.getNewSize());
102 }
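// For example, appending one coordinate to the coordinates array of level
// `lvl` conceptually performs (buffers grow geometrically once the capacity
// is exhausted; this is a sketch of the push_back semantics, not literal IR):
//
//   newBuf, newSize = push_back(curSize, oldBuf, value)
//   desc.memRef[CrdMemRef][lvl]    = newBuf   // possibly a reallocated buffer
//   desc.specifier.crdMemSize[lvl] = newSize  // stored size, not capacity
//
// With a `repeat` operand, `repeat` copies of `value` are appended at once.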
103 
104 /// Generates code that allocates a sparse storage scheme for given rank.
105 static void allocSchemeForRank(OpBuilder &builder, Location loc,
106  MutSparseTensorDescriptor desc, Level startLvl) {
107  const SparseTensorType stt(desc.getRankedTensorType());
108  Value linear = constantIndex(builder, loc, 1);
109  const Level lvlRank = stt.getLvlRank();
110  for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
111  const auto lt = stt.getLvlType(lvl);
112  if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
113  // Append linear x positions, initialized to zero. Since each compressed
114  // dimension initially already has a single zero entry, this maintains
115  // the desired "linear + 1" length property at all times. For loose
116  // compression, we multiply linear by two in order to append both the
117  // lo/hi positions.
118  Value posZero = constantZero(builder, loc, stt.getPosType());
119  if (isLooseCompressedLT(lt)) {
120  Value two = constantIndex(builder, loc, 2);
121  linear = builder.create<arith::MulIOp>(loc, linear, two);
122  }
123  createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
124  /*value=*/posZero, /*repeat=*/linear);
125  return;
126  } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
127  return; // nothing to do
128  }
129  // Keep compounding the size, but nothing needs to be initialized
130  // at this level. We will eventually reach a compressed level or
131  // otherwise the values array for the from-here "all-dense" case.
132  assert(isDenseLT(lt));
133  Value size = desc.getLvlSize(builder, loc, lvl);
134  linear = builder.create<arith::MulIOp>(loc, linear, size);
135  }
136  // Reached values array so prepare for an insertion.
137  Value valZero = constantZero(builder, loc, stt.getElementType());
138  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
139  std::nullopt, /*value=*/valZero, /*repeat=*/linear);
140 }
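// Example (a sketch): for a CSR-like tensor with level types
// (dense, compressed) and allocSchemeForRank(.., startLvl = 0), the dense
// level multiplies `linear` by its size, and the compressed level then
// appends that many zero entries to its positions array before returning,
// which preserves the "linear + 1" length property established by
// createAllocFields. The values array is only prepared here for an
// all-dense suffix of levels.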
141 
142 /// Creates allocation operation.
143 static Value createAllocation(OpBuilder &builder, Location loc,
144  MemRefType memRefType, Value sz,
145  bool enableInit) {
146  Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
147  Type elemType = memRefType.getElementType();
148  if (enableInit) {
149  Value fillValue = constantZero(builder, loc, elemType);
150  builder.create<linalg::FillOp>(loc, fillValue, buffer);
151  }
152  return buffer;
153 }
154 
155 /// Creates the dim sizes array, filling in from dynamic sizes.
156 static void createDimSizes(OpBuilder &builder, Location loc,
157  SparseTensorType stt, ValueRange dynSizes,
158  /*out*/ SmallVectorImpl<Value> &dimSizesValues) {
159  const Dimension dimRank = stt.getDimRank();
160  dimSizesValues.clear();
161  dimSizesValues.reserve(dimRank);
162  unsigned i = 0;
163  for (const Size sz : stt.getDimShape())
164  dimSizesValues.push_back(ShapedType::isDynamic(sz)
165  ? dynSizes[i++]
166  : constantIndex(builder, loc, sz));
167 }
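// E.g., for a static/dynamic dimension shape 8 x ? and dynSizes = (%d),
// this yields dimSizesValues = (constant 8, %d).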
168 
169 /// Creates allocation for each field in sparse tensor type. Note that
170 /// for all dynamic memrefs in the sparse tensor storage layout, the
171 /// memory size is really the capacity of the "vector", while the actual
172 /// size resides in the sizes array.
173 static void createAllocFields(OpBuilder &builder, Location loc,
174  SparseTensorType stt, bool enableInit,
175  Value sizeHint,
176  SmallVectorImpl<Value> &lvlSizesValues,
177  /*out*/ SmallVectorImpl<Value> &fields) {
178  Level lvlRank = stt.getLvlRank();
179  // Set up some heuristic sizes. We try to set the initial
180  // size based on available information. Otherwise we just
181  // initialize a few elements to start the reallocation chain.
182  // TODO: refine this
183  Value posHeuristic, crdHeuristic, valHeuristic;
184  if (stt.isAllDense()) {
185  valHeuristic = lvlSizesValues[0];
186  for (Level lvl = 1; lvl < lvlRank; lvl++)
187  valHeuristic =
188  builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
189  } else if (sizeHint) {
190  if (stt.getAoSCOOStart() == 0) {
191  posHeuristic = constantIndex(builder, loc, 2);
192  crdHeuristic = builder.create<arith::MulIOp>(
193  loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
194  } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
195  posHeuristic = builder.create<arith::AddIOp>(
196  loc, sizeHint, constantIndex(builder, loc, 1));
197  crdHeuristic = sizeHint;
198  } else {
199  posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
200  }
201  valHeuristic = sizeHint;
202  } else {
203  posHeuristic = crdHeuristic = valHeuristic =
204  constantIndex(builder, loc, 16);
205  }
206  // Initializes all fields: an initial storage specifier and allocated
207  // positions/coordinates/values memrefs (with heuristic capacity).
208  foreachFieldAndTypeInSparseTensor(
209  stt,
210  [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
211  enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
212  Level /*lvl*/, LevelType /*lt*/) -> bool {
213  assert(fields.size() == fIdx);
214  Value field;
215  switch (fKind) {
216  case SparseTensorFieldKind::StorageSpec:
217  field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
218  break;
219  case SparseTensorFieldKind::PosMemRef:
220  field = createAllocation(builder, loc, cast<MemRefType>(fType),
221  posHeuristic, enableInit);
222  break;
223  case SparseTensorFieldKind::CrdMemRef:
224  field = createAllocation(builder, loc, cast<MemRefType>(fType),
225  crdHeuristic, enableInit);
226  break;
227  case SparseTensorFieldKind::ValMemRef:
228  field = createAllocation(builder, loc, cast<MemRefType>(fType),
229  valHeuristic, enableInit);
230  break;
231  }
232  assert(field);
233  fields.push_back(field);
234  // Returns true to continue the iteration.
235  return true;
236  });
237  // Initialize the storage scheme to an empty tensor. Sets the lvlSizes
238  // and gives all position fields an initial zero entry, so that it is
239  // easier to maintain the "linear + 1" length property.
240  MutSparseTensorDescriptor desc(stt, fields);
241  Value posZero = constantZero(builder, loc, stt.getPosType());
242  for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
243  desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
244  const auto lt = stt.getLvlType(lvl);
245  if (isCompressedLT(lt) || isLooseCompressedLT(lt))
246  createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
247  /*value=*/posZero);
248  }
249  allocSchemeForRank(builder, loc, desc, /*rank=*/0);
250 }
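// As a rough illustration (not a sizing contract): for a CSR-like tensor
// allocated with a size hint of n stored entries, the heuristics above pick
// capacities of about n+1 for positions and n for coordinates and values;
// without a size hint, each buffer starts at 16 entries. After allocation,
// the level sizes are recorded in the specifier, every compressed positions
// array receives its initial zero entry, and allocSchemeForRank finalizes
// the dense prefix.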
251 
252 /// Helper method that generates block specific to compressed case:
253 ///
254 /// // given: parentPos = posCursor[lvl-1]
255 /// pstart = desc.positions[lvl][parentPos]
256 /// pstop = desc.positions[lvl][parentPos+1]
257 /// plast = pstop - 1
258 /// msz = desc.coordinates[lvl].size()
259 /// if (pstart < pstop) {
260 /// isPresent = (desc.coordinates[lvl][plast] == lvlCoords[lvl])
261 /// } else { // first insertion
262 /// isPresent = false
263 /// desc.positions[lvl][parentPos] = msz
264 /// }
265 /// if (isPresent) { // coordinate is already present
266 /// pnext = plast
267 /// } else {
268 /// desc.coordinates[lvl].push_back(lvlCoords[lvl])
269 /// desc.positions[lvl][parentPos+1] = msz+1
270 /// pnext = msz
271 /// <prepare level lvl+1>
272 /// }
273 /// posCursor[lvl] = pnext
274 static Value genCompressed(OpBuilder &builder, Location loc,
275  MutSparseTensorDescriptor desc, ValueRange lvlCoords,
276  Value /*unused*/, Value parentPos, Level lvl) {
277  const SparseTensorType stt(desc.getRankedTensorType());
278  const Level lvlRank = stt.getLvlRank();
279  assert(lvl < lvlRank && "Level is out of bounds");
280  assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
281  "Level-rank mismatch");
282  SmallVector<Type> types;
283  Type indexType = builder.getIndexType();
284  Type boolType = builder.getIntegerType(1);
285  unsigned crdFidx;
286  unsigned crdStride;
287  std::tie(crdFidx, crdStride) = desc.getCrdMemRefIndexAndStride(lvl);
288  const Value one = constantIndex(builder, loc, 1);
289  const Value pp1 = builder.create<arith::AddIOp>(loc, parentPos, one);
290  const Value positionsAtLvl = desc.getPosMemRef(lvl);
291  const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
292  const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
293  const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl);
294  const Value crdStrideC =
295  crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
296  const Value msz =
297  crdStrideC ? builder.create<arith::DivUIOp>(loc, crdMsz, crdStrideC)
298  : crdMsz;
299  const Value plast = builder.create<arith::SubIOp>(
300  loc, genCast(builder, loc, pstop, indexType), one);
301  // Conditional expression.
302  Value lt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
303  pstart, pstop);
304  types.push_back(boolType);
305  scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
306  types.pop_back();
307  builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
308  Value crd =
309  genLoad(builder, loc, desc.getMemRefField(crdFidx),
310  crdStrideC ? builder.create<arith::MulIOp>(loc, plast, crdStrideC)
311  : plast);
312  Value eq = builder.create<arith::CmpIOp>(
313  loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
314  lvlCoords[lvl]);
315  builder.create<scf::YieldOp>(loc, eq);
316  builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
317  if (lvl > 0)
318  genStore(builder, loc, msz, positionsAtLvl, parentPos);
319  builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
320  builder.setInsertionPointAfter(ifOp1);
321  // If present construct. Note that for a non-unique dimension level, we
322  // simply set the condition to false and rely on CSE/DCE to clean up the IR.
323  //
324  // TODO: generate less temporary IR?
325  //
326  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
327  types.push_back(desc.getField(i).getType());
328  types.push_back(indexType);
329  const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
330  : constantI1(builder, loc, false);
331  scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true);
332  // If present (fields unaffected, update pnext to plast).
333  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
334 
335  // FIXME: This does not look like a clean way, but it is probably the most
336  // efficient way.
337  desc.getFields().push_back(plast);
338  builder.create<scf::YieldOp>(loc, desc.getFields());
339  desc.getFields().pop_back();
340 
341  // If !present (changes fields, update pnext).
342  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
343  Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
344  genStore(builder, loc, mszp1, positionsAtLvl, pp1);
345  createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
346  /*value=*/lvlCoords[lvl]);
347  // Prepare the next level "as needed".
348  if ((lvl + 1) < lvlRank)
349  allocSchemeForRank(builder, loc, desc, lvl + 1);
350 
351  desc.getFields().push_back(msz);
352  builder.create<scf::YieldOp>(loc, desc.getFields());
353  desc.getFields().pop_back();
354 
355  // Update fields and return next pos.
356  builder.setInsertionPointAfter(ifOp2);
357  unsigned o = 0;
358  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
359  desc.setField(i, ifOp2.getResult(o++));
360  return ifOp2.getResult(o);
361 }
362 
363 /// Generates insertion finalization code.
364 static void genEndInsert(OpBuilder &builder, Location loc,
365  SparseTensorDescriptor desc) {
366  const SparseTensorType stt(desc.getRankedTensorType());
367  const Level lvlRank = stt.getLvlRank();
368  for (Level lvl = 0; lvl < lvlRank; lvl++) {
369  const auto lt = stt.getLvlType(lvl);
370  if (isCompressedLT(lt)) {
371  // Compressed dimensions need a position cleanup for all entries
372  // that were not visited during the insertion pass.
373  //
374  // TODO: avoid cleanup and keep compressed scheme consistent at all
375  // times?
376  //
377  if (lvl > 0) {
378  Type posType = stt.getPosType();
379  Value posMemRef = desc.getPosMemRef(lvl);
380  Value hi = desc.getPosMemSize(builder, loc, lvl);
381  Value zero = constantIndex(builder, loc, 0);
382  Value one = constantIndex(builder, loc, 1);
383  // Vector of only one, but needed by createFor's prototype.
384  SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
385  scf::ForOp loop = createFor(builder, loc, hi, inits, one);
386  Value i = loop.getInductionVar();
387  Value oldv = loop.getRegionIterArg(0);
388  Value newv = genLoad(builder, loc, posMemRef, i);
389  Value posZero = constantZero(builder, loc, posType);
390  Value cond = builder.create<arith::CmpIOp>(
391  loc, arith::CmpIPredicate::eq, newv, posZero);
392  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType),
393  cond, /*else*/ true);
394  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
395  genStore(builder, loc, oldv, posMemRef, i);
396  builder.create<scf::YieldOp>(loc, oldv);
397  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
398  builder.create<scf::YieldOp>(loc, newv);
399  builder.setInsertionPointAfter(ifOp);
400  builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
401  builder.setInsertionPointAfter(loop);
402  }
403  } else {
404  assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
405  isNOutOfMLT(lt));
406  }
407  }
408 }
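// For example, if the last row of a CSR matrix never received an insertion,
// its positions array may read [0, 2, 3, 0] after the insertion pass; the
// cleanup loop above overwrites each unvisited zero entry with the running
// value, yielding the valid row-pointer array [0, 2, 3, 3].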
409 
410 /// Generates a subview into the sizes.
411 static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
412  Value sz) {
413  auto memTp = llvm::cast<MemRefType>(mem.getType());
414  // For higher-dimensional memrefs, we assume that the innermost
415  // dimension is always of the right size.
416  // TODO: generate complex truncating view here too?
417  if (memTp.getRank() > 1)
418  return mem;
419  // Truncate linear memrefs to given size.
420  return builder
421  .create<memref::SubViewOp>(
422  loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
423  mem, ValueRange{}, ValueRange{sz}, ValueRange{},
424  ArrayRef<int64_t>{0}, // static offset
425  ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
426  ArrayRef<int64_t>{1}) // static stride
427  .getResult();
428 }
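// E.g., a values buffer allocated with capacity 16 but holding only %sz
// entries is exposed to clients roughly as
//   memref.subview %mem[0] [%sz] [1] : memref<?xf64> to memref<?xf64>
// so that consumers observe the stored size rather than the capacity.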
429 
430 /// Creates the reassociation array.
431 static SmallVector<ReassociationIndices>
432 getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
433  SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
434  // Create reassociation in the form:
435  // {0}, {1}, ..., {batchLvl - 1}, {batchLvl, ..., rank}
436  for (unsigned i = 0; i < batchLvls; i++)
437  ret[i].push_back(i);
438 
439  for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
440  ret.back().push_back(i);
441 
442  return ret;
443 }
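// For example, a source type of rank 3 with batchLvls = 1 yields the
// reassociation {{0}, {1, 2}}: the batch dimension is kept, and the two
// trailing dimensions are collapsed into one.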
444 
445 //===----------------------------------------------------------------------===//
446 // Codegen rules.
447 //===----------------------------------------------------------------------===//
448 
449 namespace {
450 
451 /// Helper class to help lowering sparse_tensor.insert operation.
452 class SparseInsertGenerator
453  : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
454 public:
455  SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
456  bool genCall)
457  : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp){};
458 
459  /// Generates code along an insertion path without the need for a "cursor".
460  /// This current insertion strategy comes at the expense of some testing
461  /// overhead for each insertion. The strategy will be optimized later for
462  /// common insertion patterns. The current insertion strategy also assumes
463  /// insertions occur in "a reasonable order" that enables building the
464  /// storage scheme in an appending/inserting kind of fashion (i.e. no
465  /// in-between insertions that need data movement). The implementation
466  /// relies on CSE/DCE to clean up all bookkeeping that is not needed.
467  ///
468  /// TODO: better unord/not-unique; also generalize, optimize, specialize!
469  SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
470  OpBuilder &builder, Location loc) {
471  const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
472  const Level lvlRank = stt.getLvlRank();
473  // Extract fields and coordinates from args.
474  SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
475  MutSparseTensorDescriptor desc(stt, fields);
476  const SmallVector<Value> coords =
477  llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
478  Value value = args.back();
479  Value parentPos = constantZero(builder, loc, builder.getIndexType());
480  // Generate code for every level.
481  for (Level lvl = 0; lvl < lvlRank; lvl++) {
482  const auto lt = stt.getLvlType(lvl);
483  if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
484  // Create:
485  // if (!present) {
486  // coordinates[lvl].push_back(coords[lvl])
487  // <update positions and prepare level lvl + 1>
488  // }
489  // positions[lvl] = coordinates.size() - 1
490  // <insert @ positions[lvl] at next level lvl + 1>
491  if (isLooseCompressedLT(lt)) {
492  Value two = constantIndex(builder, loc, 2);
493  parentPos = builder.create<arith::MulIOp>(loc, parentPos, two);
494  }
495  parentPos =
496  genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
497  } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
498  // Create:
499  // coordinates[lvl].push_back(coords[lvl])
500  // positions[lvl] = positions[lvl-1]
501  // <insert @ positions[lvl] at next level lvl + 1>
502  createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
503  lvl, /*value=*/coords[lvl]);
504  } else {
505  assert(isDenseLT(lt));
506  // Construct the new position as:
507  // positions[lvl] = size * positions[lvl-1] + coords[lvl]
508  // <insert @ positions[lvl] at next level lvl + 1>
509  Value size = desc.getLvlSize(builder, loc, lvl);
510  Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
511  parentPos = builder.create<arith::AddIOp>(loc, mult, coords[lvl]);
512  }
513  }
514  // Reached the actual value append/insert.
515  if (!stt.isDenseLvl(lvlRank - 1))
516  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
517  std::nullopt, value);
518  else
519  genStore(builder, loc, value, desc.getValMemRef(), parentPos);
520  return fields;
521  }
522 
523  std::string getMangledFuncName() {
524  // The mangled name of the function has this format:
525  // <namePrefix>_<LT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
526  constexpr const char kInsertFuncNamePrefix[] = "_insert_";
527  const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
528  SmallString<32> nameBuffer;
529  llvm::raw_svector_ostream nameOstream(nameBuffer);
530  nameOstream << kInsertFuncNamePrefix;
531  const Level lvlRank = stt.getLvlRank();
532  for (Level l = 0; l < lvlRank; l++) {
533  std::string lvlType = toMLIRString(stt.getLvlType(l));
534  // Replace/remove punctuation in level properties.
535  std::replace_if(
536  lvlType.begin(), lvlType.end(),
537  [](char c) { return c == '(' || c == ','; }, '_');
538  llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
539  nameOstream << lvlType << "_";
540  }
541  // Static dim sizes are used in the generated code while dynamic sizes are
542  // loaded from the dimSizes buffer. This is the reason for adding the shape
543  // to the function name.
544  for (const auto sz : stt.getDimShape())
545  nameOstream << sz << "_";
546  // Permutation information is also used in generating insertion.
547  if (!stt.isIdentity())
548  nameOstream << stt.getDimToLvl() << "_";
549  nameOstream << stt.getElementType() << "_";
550  nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
551  return nameOstream.str().str();
552  }
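  // As an illustration (assuming the default crd/pos bit widths of 0),
  // inserting into a static 8x8 CSR matrix of f64 would use a mangled name
  // roughly like:
  //   _insert_dense_compressed_8_8_f64_0_0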
553 
554 private:
555  TensorType rtp;
556 };
557 
558 /// Sparse tensor storage conversion rule for returns.
559 class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
560 public:
561  using OpConversionPattern::OpConversionPattern;
562  LogicalResult
563  matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
564  ConversionPatternRewriter &rewriter) const override {
565  // Create a return with the flattened value extracted from sparse tensors.
566  rewriter.replaceOpWithNewOp<func::ReturnOp>(
567  op, flattenValues(adaptor.getOperands()));
568  return success();
569  }
570 };
571 
572 /// Sparse tensor storage conversion rule for calls.
573 class SparseCallConverter : public OpConversionPattern<func::CallOp> {
574 public:
575  // The default CallOp converter cannot handle 1:N type conversion.
576  using OpConversionPattern::OpConversionPattern;
577  LogicalResult
578  matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
579  ConversionPatternRewriter &rewriter) const override {
580  Location loc = op.getLoc();
581  // In case of:
582  // sparse_tensor, f, sparse_tensor = call @foo(...)
583  // ==>
584  // memref..., f, memref = call @foo(...) replace with
585  // cast(memref...)->sparse_tensor, f, cast(memref...)->sparse_tensor
586  SmallVector<Type> finalRetTy;
587  if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
588  return failure();
589 
590  // (1) Generates new call with flattened return value.
591  auto newCall = rewriter.create<func::CallOp>(
592  loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands()));
593  // (2) Gather sparse tensor returns.
594  SmallVector<SmallVector<Value>> packedResultVals;
595  // Tracks the offset of the current return value (of the original call)
596  // relative to the new call (after sparse tensor flattening).
597  unsigned retOffset = 0;
598  // Temporary buffer to hold the flattened list of types for
599  // a sparse tensor.
600  SmallVector<Type> sparseFlat;
601  for (auto ret : op.getResults()) {
602  assert(retOffset < newCall.getNumResults());
603  auto retType = ret.getType();
604  if (failed(typeConverter->convertType(retType, sparseFlat)))
605  llvm_unreachable("Failed to convert type in sparse tensor codegen");
606 
607  // Converted types cannot be empty when the type conversion succeeds.
608  assert(!sparseFlat.empty());
609  if (sparseFlat.size() > 1) {
610  auto flatSize = sparseFlat.size();
611  packedResultVals.emplace_back();
612  llvm::append_range(packedResultVals.back(),
613  newCall.getResults().slice(retOffset, flatSize));
614  retOffset += flatSize;
615  } else {
616  // If this is a 1:1 conversion, no need for casting.
617  packedResultVals.emplace_back();
618  packedResultVals.back().push_back(newCall.getResult(retOffset));
619  retOffset++;
620  }
621  sparseFlat.clear();
622  }
623 
624  assert(packedResultVals.size() == op.getNumResults());
625  rewriter.replaceOpWithMultiple(
626  op, llvm::to_vector_of<ValueRange>(packedResultVals));
627  return success();
628  }
629 };
630 
631 /// Sparse codegen rule for level accesses.
632 class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
633 public:
634  using OpConversionPattern::OpConversionPattern;
635  LogicalResult
636  matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
637  ConversionPatternRewriter &rewriter) const override {
638  std::optional<int64_t> lvl = op.getConstantLvlIndex();
639  RankedTensorType srcType = op.getSource().getType();
640  if (!lvl || !getSparseTensorEncoding(srcType))
641  return failure();
642 
643  auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
644  auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
645 
646  rewriter.replaceOp(op, sz);
647  return success();
648  }
649 };
650 
651 // TODO: use a new SortCOO operation here instead of reusing convert op.
652 struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
653  using OpConversionPattern::OpConversionPattern;
654  LogicalResult
655  matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
656  ConversionPatternRewriter &rewriter) const override {
657  Location loc = op.getLoc();
658  MLIRContext *ctx = op.getContext();
659 
660  SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
661  SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());
662 
663  // Should have been verified.
664  assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
665  dstStt.isCOOType() && srcStt.isCOOType());
666  assert(dstStt.hasSameDimToLvl(srcStt));
667 
668  // We don't need a mutable descriptor here as we perform sorting in-place.
669  auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
670  op.getInputCoo().getType());
671  auto nnz = desc.getValMemSize(rewriter, op.getLoc());
672  auto crd = desc.getAOSMemRef();
673  auto val = desc.getValMemRef();
674 
675  // Otherwise we need another data shuffle and a non-identity map.
676  assert(dstStt.hasSameDimToLvl(srcStt));
677  (void)dstStt; // to silence warning when assertion is disabled
678 
679  auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
680 
681  rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
682  rewriter.getIndexAttr(0), op.getAlgorithm());
683 
684  // Since we do in-place sorting, the destination tensor will have the same set
685  // of memrefs as the source tensor.
686  rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()});
687  return success();
688  }
689 };
690 
691 template <typename Op, StorageSpecifierKind kind>
692 class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
693 public:
694  using OpConversionPattern<Op>::OpConversionPattern;
695  using typename OpConversionPattern<Op>::OneToNOpAdaptor;
696 
697  LogicalResult
698  matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
699  ConversionPatternRewriter &rewriter) const override {
700  // Simply lowers to a specifier.get <field> operation.
701  auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
702  op.getSlice().getType());
703  auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
704  op.getDim().getZExtValue());
705 
706  rewriter.replaceOp(op, v);
707  return success();
708  }
709 };
710 
711 /// Sparse codegen rule for trivial tensor casts.
712 class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
713 public:
714  using OpConversionPattern::OpConversionPattern;
715  LogicalResult
716  matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
717  ConversionPatternRewriter &rewriter) const override {
718  // Only rewrite identically annotated source/dest.
719  auto encDst = getSparseTensorEncoding(op.getType());
720  auto encSrc = getSparseTensorEncoding(op.getSource().getType());
721  if (!encDst || encDst != encSrc)
722  return failure();
723  rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
724  return success();
725  }
726 };
727 
728 class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
729 public:
730  using OpConversionPattern::OpConversionPattern;
731  LogicalResult
732  matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
733  ConversionPatternRewriter &rewriter) const override {
734  // Simply fold the operation.
735  rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
736  return success();
737  }
738 };
739 
740 /// Sparse codegen rule for the alloc operator.
741 class SparseTensorAllocConverter
742  : public OpConversionPattern<bufferization::AllocTensorOp> {
743 public:
744  using OpConversionPattern::OpConversionPattern;
745  SparseTensorAllocConverter(const TypeConverter &typeConverter,
746  MLIRContext *context, bool enableInit)
747  : OpConversionPattern(typeConverter, context),
748  enableBufferInitialization(enableInit) {}
749 
750  LogicalResult
751  matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
752  ConversionPatternRewriter &rewriter) const override {
753  const auto resType = getSparseTensorType(op);
754  if (!resType.hasEncoding())
755  return failure();
756 
757  Location loc = op.getLoc();
758  // Deal with copy.
759  if (op.getCopy()) {
760  auto desc = getDescriptorFromTensorTuple(
761  adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
762  SmallVector<Value> fields;
763  fields.reserve(desc.getNumFields());
764  // Memcpy on memref fields.
765  for (auto field : desc.getMemRefFields()) {
766  auto memrefTp = cast<MemRefType>(field.getType());
767  auto size = rewriter.create<memref::DimOp>(loc, field, 0);
768  auto copied =
769  rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size});
770  rewriter.create<memref::CopyOp>(loc, field, copied);
771  fields.push_back(copied);
772  }
773  // Reuses specifier.
774  fields.push_back(desc.getSpecifier());
775  assert(fields.size() == desc.getNumFields());
776  rewriter.replaceOpWithMultiple(op, {fields});
777  return success();
778  }
779 
780  if (!resType.isIdentity()) {
781  return rewriter.notifyMatchFailure(
782  op, "try run --sparse-reinterpret-map before codegen");
783  }
784  // Level size equals dimension size since the lvl2dim map is an identity map.
785  SmallVector<Value> lvlSizesValues;
786  createDimSizes(rewriter, loc, resType,
787  flattenValues(adaptor.getDynamicSizes()),
788  /*dimSizesValues=*/lvlSizesValues);
789 
790  // Construct allocation for each field.
791  Value sizeHint = op.getSizeHint();
792  SmallVector<Value> fields;
793  createAllocFields(rewriter, loc, resType, enableBufferInitialization,
794  sizeHint, lvlSizesValues, fields);
795 
796  // Replace operation with resulting memrefs.
797  rewriter.replaceOpWithMultiple(op, {fields});
798  return success();
799  }
800 
801 private:
802  bool enableBufferInitialization;
803 };
804 
805 /// Sparse codegen rule for the empty tensor operator.
806 class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
807 public:
808  using OpConversionPattern::OpConversionPattern;
809  SparseTensorEmptyConverter(const TypeConverter &typeConverter,
810  MLIRContext *context, bool enableInit)
811  : OpConversionPattern(typeConverter, context),
812  enableBufferInitialization(enableInit) {}
813 
814  LogicalResult
815  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
816  ConversionPatternRewriter &rewriter) const override {
817  const auto resType = getSparseTensorType(op);
818  if (!resType.hasEncoding())
819  return failure();
820 
821  if (!resType.isIdentity()) {
822  return rewriter.notifyMatchFailure(
823  op, "try run --sparse-reinterpret-map before codegen");
824  }
825 
826  Location loc = op.getLoc();
827  // Level size equals dimension size since the lvl2dim map is an identity map.
828  SmallVector<Value> lvlSizesValues;
829  createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
830  /*dimSizesValues=*/lvlSizesValues);
831  // Construct allocation for each field.
832  Value sizeHint; // none
833  SmallVector<Value> fields;
834  createAllocFields(rewriter, loc, resType, enableBufferInitialization,
835  sizeHint, lvlSizesValues, fields);
836 
837  // Replace operation with resulting memrefs.
838  rewriter.replaceOpWithMultiple(op, {fields});
839  return success();
840  }
841 
842 private:
843  bool enableBufferInitialization;
844 };
845 
846 /// Sparse codegen rule for the dealloc operator.
847 class SparseTensorDeallocConverter
848  : public OpConversionPattern<bufferization::DeallocTensorOp> {
849 public:
850  using OpConversionPattern::OpConversionPattern;
851  SparseTensorDeallocConverter(const TypeConverter &typeConverter,
852  MLIRContext *context, bool createDeallocs)
853  : OpConversionPattern(typeConverter, context),
854  createDeallocs(createDeallocs) {}
855 
856  LogicalResult
857  matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
858  ConversionPatternRewriter &rewriter) const override {
859  auto enc = getSparseTensorEncoding(op.getTensor().getType());
860  if (!enc)
861  return failure();
862 
863  // If the user requests not to deallocate sparse tensors, simply erase the
864  // operation.
865  if (createDeallocs) {
866  // Replace the sparse tensor deallocation with field deallocations.
867  Location loc = op.getLoc();
868  auto desc = getDescriptorFromTensorTuple(
869  adaptor.getTensor(),
870  cast<RankedTensorType>(op.getTensor().getType()));
871  for (auto input : desc.getMemRefFields())
872  // Deallocate every buffer used to store the sparse tensor.
873  rewriter.create<memref::DeallocOp>(loc, input);
874  }
875  rewriter.eraseOp(op);
876  return success();
877  }
878 
879 private:
880  const bool createDeallocs;
881 };
882 
883 /// Sparse codegen rule for tensor rematerialization.
884 class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
885 public:
886  using OpConversionPattern::OpConversionPattern;
887  LogicalResult
888  matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
889  ConversionPatternRewriter &rewriter) const override {
890  // Prepare descriptor.
891  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
892  op.getTensor().getType());
893  // Generate optional insertion finalization code.
894  if (op.getHasInserts())
895  genEndInsert(rewriter, op.getLoc(), desc);
896  // Replace operation with resulting memrefs.
897  rewriter.replaceOpWithMultiple(op, {desc.getFields()});
898  return success();
899  }
900 };
901 
902 /// Sparse codegen rule for the expand op.
903 class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
904 public:
905  using OpConversionPattern::OpConversionPattern;
906  LogicalResult
907  matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
908  ConversionPatternRewriter &rewriter) const override {
909  if (!getSparseTensorEncoding(op.getTensor().getType()))
910  return failure();
911  Location loc = op->getLoc();
912  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
913  op.getTensor().getType());
914  const auto srcType = getSparseTensorType(op.getTensor());
915  Type eltType = srcType.getElementType();
916  Type boolType = rewriter.getIntegerType(1);
917  Type idxType = rewriter.getIndexType();
918  // All initialization should be done on entry of the loop nest.
919  rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
920 
921  // Determine the size for access expansion (always the innermost stored
922  // level size).
923  const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
924  // Generate a memref for `sz` elements of type `t`.
925  const auto genAlloc = [&](Type t) {
926  const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
927  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
928  };
929  // Allocate temporary buffers for values/filled-switch and added.
930  // We do not use stack buffers for this, since the expanded size may
931  // be rather large (as it envelops a single expanded dense dimension).
932  Value values = genAlloc(eltType);
933  Value filled = genAlloc(boolType);
934  Value added = genAlloc(idxType);
935  Value zero = constantZero(rewriter, loc, idxType);
936  // Reset the values/filled-switch to all-zero/false. Note that this
937  // introduces an O(N) operation into the computation, but this reset
938  // operation is amortized over the innermost loops for the access
939  // pattern expansion. As noted in the operation doc, we would like
940  // to amortize this setup cost even between kernels.
941  rewriter.create<linalg::FillOp>(
942  loc, ValueRange{constantZero(rewriter, loc, eltType)},
943  ValueRange{values});
944  rewriter.create<linalg::FillOp>(
945  loc, ValueRange{constantZero(rewriter, loc, boolType)},
946  ValueRange{filled});
947  // Replace expansion op with these buffers and initial coordinate.
948  assert(op.getNumResults() == 4);
949  rewriter.replaceOp(op, {values, filled, added, zero});
950  return success();
951  }
952 };
953 
954 /// Sparse codegen rule for the compress operator.
955 class SparseCompressConverter : public OpConversionPattern<CompressOp> {
956 public:
957  using OpConversionPattern::OpConversionPattern;
958  LogicalResult
959  matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
960  ConversionPatternRewriter &rewriter) const override {
961  Location loc = op->getLoc();
962  SmallVector<Value> fields;
963  auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
964  op.getTensor().getType());
965  Value values = getSingleValue(adaptor.getValues());
966  Value filled = getSingleValue(adaptor.getFilled());
967  Value added = getSingleValue(adaptor.getAdded());
968  Value count = getSingleValue(adaptor.getCount());
969  const SparseTensorType dstType(desc.getRankedTensorType());
970  Type eltType = dstType.getElementType();
971 
972  // If the innermost level is ordered, we need to sort the coordinates
973  // in the "added" array prior to applying the compression.
974  if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
975  rewriter.create<SortOp>(
976  loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
977  rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
978  // While performing the insertions, we also need to reset the elements
979  // of the values/filled-switch by only iterating over the set elements,
980  // to ensure that the runtime complexity remains proportional to the
981  // sparsity of the expanded access pattern.
982  //
983  // Generate
984  // out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
985  // crd = added[i];
986  // value = values[crd];
987  // insert({lvlCoords, crd}, value);
988  // new_memrefs = insert(in_memrefs, {lvlCoords, crd}, value);
989  // values[crd] = 0;
990  // filled[crd] = false;
991  // yield new_memrefs
992  // }
993  scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
994  Value i = loop.getInductionVar();
995 
996  Value crd = genLoad(rewriter, loc, added, i);
997  Value value = genLoad(rewriter, loc, values, crd);
998  SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
999  SmallVector<Type> flatSpTensorTps = llvm::to_vector(
1000  llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
1001  SmallVector<Value> flatLvlCoords = flattenValues(adaptor.getLvlCoords());
1002  params.append(flatLvlCoords.begin(), flatLvlCoords.end());
1003  params.push_back(crd);
1004  params.push_back(value);
1005  SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
1006  params, /*genCall=*/true);
1007  SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
1008  genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
1009  genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
1010  rewriter.create<scf::YieldOp>(loc, insertRet);
1011 
1012  rewriter.setInsertionPointAfter(loop);
1013  // Deallocate the buffers on exit of the full loop nest.
1014  Operation *parent = getTop(op);
1015  rewriter.setInsertionPointAfter(parent);
1016  rewriter.create<memref::DeallocOp>(loc, values);
1017  rewriter.create<memref::DeallocOp>(loc, filled);
1018  rewriter.create<memref::DeallocOp>(loc, added);
1019  // Replace operation with resulting memrefs.
1020  rewriter.replaceOpWithMultiple(op, {loop->getResults()});
1021  return success();
1022  }
1023 };
1024 
1025 /// Sparse codegen rule for the insert operator.
1026 class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
1027 public:
1028  using OpConversionPattern::OpConversionPattern;
1029  LogicalResult
1030  matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
1031  ConversionPatternRewriter &rewriter) const override {
1032  auto stt = getSparseTensorType(op.getDest());
1033  if (!stt.hasEncoding())
1034  return failure();
1035  assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
1036 
1037  Location loc = op.getLoc();
1038  auto desc =
1039  getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
1040  TypeRange flatSpTensorTps = desc.getFields().getTypes();
1041  SmallVector<Value> params = llvm::to_vector(desc.getFields());
1042  SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
1043  params.append(flatIndices.begin(), flatIndices.end());
1044  params.push_back(getSingleValue(adaptor.getScalar()));
1045  SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
1046  params, /*genCall=*/true);
1047  SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
1048  // Replace operation with resulting memrefs.
1049  rewriter.replaceOpWithMultiple(op, {ret});
1050  return success();
1051  }
1052 };
1053 
1054 /// Sparse codegen rule for position accesses.
1055 class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
1056 public:
1057  using OpAdaptor = typename ToPositionsOp::Adaptor;
1058  using OpConversionPattern::OpConversionPattern;
1059  LogicalResult
1060  matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
1061  ConversionPatternRewriter &rewriter) const override {
1062  // Replace the requested position access with corresponding field.
1063  // The view is restricted to the actual size to ensure clients
1064  // of this operation truly observe size, not capacity!
1065  Location loc = op.getLoc();
1066  Level lvl = op.getLevel();
1067  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1068  op.getTensor().getType());
1069  auto mem = desc.getPosMemRef(lvl);
1070  auto size = desc.getPosMemSize(rewriter, loc, lvl);
1071  rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1072  return success();
1073  }
1074 };
1075 
1076 /// Sparse codegen rule for accessing the coordinates arrays.
1077 class SparseToCoordinatesConverter
1078  : public OpConversionPattern<ToCoordinatesOp> {
1079 public:
1080  using OpAdaptor = typename ToCoordinatesOp::Adaptor;
1081  using OpConversionPattern::OpConversionPattern;
1082  LogicalResult
1083  matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
1084  ConversionPatternRewriter &rewriter) const override {
1085  // Replace the requested coordinates access with corresponding field.
1086  // The view is restricted to the actual size to ensure clients
1087  // of this operation truly observe size, not capacity!
1088  Location loc = op.getLoc();
1089  Level lvl = op.getLevel();
1090  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1091  op.getTensor().getType());
1092  auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
1093  if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
1094  auto size = desc.getCrdMemSize(rewriter, loc, lvl);
1095  mem = genSliceToSize(rewriter, loc, mem, size);
1096  }
1097  rewriter.replaceOp(op, mem);
1098  return success();
1099  }
1100 };
1101 
1102 /// Sparse codegen rule for accessing the linear coordinates buffer.
1103 class SparseToCoordinatesBufferConverter
1104  : public OpConversionPattern<ToCoordinatesBufferOp> {
1105 public:
1106  using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
1107  using OpConversionPattern::OpConversionPattern;
1108  LogicalResult
1109  matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
1110  ConversionPatternRewriter &rewriter) const override {
1111  // Replace the requested coordinates access with corresponding field.
1112  // The view is restricted to the actual size to ensure clients
1113  // of this operation truly observe size, not capacity!
1114  Location loc = op.getLoc();
1115  Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
1116  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1117  op.getTensor().getType());
1118  auto mem = desc.getAOSMemRef();
1119  auto size = desc.getCrdMemSize(rewriter, loc, lvl);
1120  rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1121  return success();
1122  }
1123 };
1124 
1125 /// Sparse codegen rule for value accesses.
1126 class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
1127 public:
1128  using OpAdaptor = typename ToValuesOp::Adaptor;
1129  using OpConversionPattern::OpConversionPattern;
1130  LogicalResult
1131  matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
1132  ConversionPatternRewriter &rewriter) const override {
1133  // Replace the requested values access with corresponding field.
1134  // The view is restricted to the actual size to ensure clients
1135  // of this operation truly observe size, not capacity!
1136  Location loc = op.getLoc();
1137  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1138  op.getTensor().getType());
1139  auto mem = desc.getValMemRef();
1140  auto size = desc.getValMemSize(rewriter, loc);
1141  rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1142  return success();
1143  }
1144 };
1145 
1146 /// Sparse codegen rule for the convert operator.
1147 class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
1148 public:
1149  using OpConversionPattern::OpConversionPattern;
1150  LogicalResult
1151  matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
1152  ConversionPatternRewriter &rewriter) const override {
1153  SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
1154  SparseTensorEncodingAttr encSrc =
1155  getSparseTensorEncoding(op.getSource().getType());
1156  // The output tensor cannot be a slice; those cases should have been
1157  // rejected by ConvertOp::verify() already.
1158  assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slice.");
1159  // Different encoding (except for different bitwidth) should be handled by
1160  // rewriting.
1161  // We need further rewrites if the input tensor is a slice too.
1162  if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
1163  encSrc.isSlice()) {
1164  return failure();
1165  }
1166 
1167  Type retElemTp = op.getResult().getType().getElementType();
1168  Type srcElemTp = op.getSource().getType().getElementType();
1169  // Fold the trivial cases.
1170  if (retElemTp == srcElemTp && encDst == encSrc) {
1171  rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
1172  return success();
1173  }
1174  //
1175  // Do element-wise type conversion without using InsertOp.
1176  //
1177  // for each memref in srcTensor:
1178  // dst = memref.alloc
1179  // if srcMemRefType != dstMemRefType:
1180  // for every dst[i] = cast(src[i])
1181  // else:
1182  // dst = memref.copy(src)
1183  Location loc = op.getLoc();
1184  auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
1185  op.getSource().getType());
1186  SmallVector<Value> fields;
1187  foreachFieldAndTypeInSparseTensor(
1188  SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
1189  [&rewriter, &fields, srcDesc,
1190  loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
1191  LevelType /*lt*/) -> bool {
1192  // Simply reuses the storage specifier as it is an SSA value.
1193  if (fKind == SparseTensorFieldKind::StorageSpec) {
1194  fields.push_back(srcDesc.getSpecifier());
1195  } else {
1196  // Allocates new memrefs
1197  Value srcMem = srcDesc.getMemRefField(fIdx);
1198  // TODO: We can instead use the actual memSize in specifier, that
1199  // would require a subViewOp to avoid overflow when copying
1200  // values.
1201  Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
1202  auto dstMem = rewriter.create<memref::AllocOp>(
1203  loc, cast<MemRefType>(fTp), sz);
1204  if (fTp != srcMem.getType()) {
1205  // Converts elements type.
1206  scf::buildLoopNest(
1207  rewriter, loc, constantIndex(rewriter, loc, 0), sz,
1208  constantIndex(rewriter, loc, 1),
1209  [srcMem, &dstMem](OpBuilder &builder, Location loc,
1210  ValueRange ivs) {
1211  Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
1212  Value casted = genCast(builder, loc, v,
1213  dstMem.getType().getElementType());
1214  builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
1215  });
1216  } else {
1217  // TODO: We can even reuse the same memref for the new tensor,
1218  // but that requires a `ref-counting` based memory management
1219  // for shared memrefs between multiple sparse tensors.
1220  rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
1221  }
1222  fields.push_back(dstMem);
1223  }
1224  return true;
1225  });
1226 
1227  rewriter.replaceOpWithMultiple(op, {fields});
1228  return success();
1229  }
1230 };
1231 
1232 class SparseExtractSliceConverter
1233  : public OpConversionPattern<tensor::ExtractSliceOp> {
1234 public:
1235  using OpConversionPattern::OpConversionPattern;
1236  LogicalResult
1237  matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
1238  ConversionPatternRewriter &rewriter) const override {
1239  Location loc = op.getLoc();
1240  MLIRContext *ctx = op.getContext();
1241  auto srcEnc = getSparseTensorEncoding(op.getSourceType());
1242  auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
1243  // TODO: We should check these in ExtractSliceOp::verify.
1244  if (!srcEnc || !dstEnc || !dstEnc.isSlice())
1245  return failure();
1246  assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
1247 
1248  SmallVector<Value> fields;
1249  auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
1250  op.getSource().getType());
1251 
1252  auto newSpec = rewriter.create<StorageSpecifierInitOp>(
1253  loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
1254  desc.setSpecifier(newSpec);
1255 
1256  // Fills in slice information.
1257  for (auto [idx, offset, size, stride] : llvm::enumerate(
1258  op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
1259  Dimension dim = idx;
1260 
1261  Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
1262  Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);
1263  Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride);
1264  // TODO: We could probably only set the dynamic values here. But it would
1265  // require us to fill the hole when casting a static slice to a dynamic
1266  // slice.
1267  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset,
1268  dim, offsetV);
1269 
1270  // FIXME: we need to distinguish level sizes and dimension sizes for slices
1271  // here. Maybe we should store slice level sizes in a different array
1272  // instead of reusing it.
1273  assert(srcEnc.isIdentity());
1274  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
1275  sizeV);
1276  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
1277  dim, strideV);
1278  }
1279 
1280  // NOTE: we cannot generate tuples directly from the descriptor here, as the
1281  // descriptor holds the original type, yet we want the slice type
1282  // here (they share every memref but with an updated specifier).
1283  rewriter.replaceOpWithMultiple(op, {desc.getFields()});
1284  return success();
1285  }
1286 };
1287 
1288 /// Sparse codegen rule for number of entries operator.
1289 class SparseNumberOfEntriesConverter
1290  : public OpConversionPattern<NumberOfEntriesOp> {
1291 public:
1292  using OpConversionPattern::OpConversionPattern;
1293  LogicalResult
1294  matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
1295  ConversionPatternRewriter &rewriter) const override {
1296  // Query memSizes for the actually stored values.
1297  // FIXME: the nse value computed in this way might be wrong when there is
1298  // any "loose_compressed" level.
1299  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1300  op.getTensor().getType());
1301  rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
1302  return success();
1303  }
1304 };
1305 
1306 struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
1307  using OpConversionPattern::OpConversionPattern;
1308  LogicalResult
1309  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
1310  ConversionPatternRewriter &rewriter) const override {
1311  Location loc = op.getLoc();
1312  const auto stt = getSparseTensorType(op.getResult());
1313 
1314  SmallVector<Value> fields;
1315 
1316  foreachFieldAndTypeInSparseTensor(
1317  stt,
1318  [&rewriter, &fields, &op, &stt,
1319  loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
1320  Level /*lvl*/, LevelType lt) -> bool {
1321  assert(fields.size() == fIdx);
1322  if (fKind == SparseTensorFieldKind::StorageSpec) {
1323  fields.push_back(
1324  SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
1325  } else {
1326  // Else simply takes the inputs.
1327  Value tensor = fKind == SparseTensorFieldKind::ValMemRef
1328  ? op.getValues()
1329  : op.getLevels()[fIdx];
1330  // TODO: handle batch.
1331  TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
1332  if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
1333  // Flattens the buffer to batchLvlRank.
1334  auto reassoc = getReassociationForFlattening(
1335  mem.getType(), stt.getBatchLvlRank());
1336  mem = rewriter.create<memref::CastOp>(
1337  loc, fType,
1338  rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
1339  } else {
1340  mem = rewriter.create<memref::CastOp>(loc, fType, mem);
1341  }
1342  fields.push_back(mem);
1343  }
1344  return true;
1345  });
1346 
1347  MutSparseTensorDescriptor desc(stt, fields);
1348  Value c0 = constantIndex(rewriter, loc, 0);
1349  Value c1 = constantIndex(rewriter, loc, 1);
1350  Value c2 = constantIndex(rewriter, loc, 2);
1351  Value posBack = c0; // index to the last value in the position array
1352  Value memSize = c1; // memory size for current array
1353 
1354  Level trailCOOStart = stt.getAoSCOOStart();
1355  Level trailCOORank = stt.getLvlRank() - trailCOOStart;
1356  // Sets up SparseTensorSpecifier.
1357  for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
1358  assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
1359 
1360  // Sets up the level size.
1361  auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
1362  desc.setLvlSize(rewriter, loc, lvl, lvlSize);
1363  // We use a single AOS array to store the trailing COO, so there is only
1364  // one memory size to set for the entire COO section.
1365  if (lvl > trailCOOStart)
1366  continue;
1367 
1368  // Sets up the memory size by reading the last value in position array.
1369  LevelType lt = stt.getLvlType(lvl);
1370  // Simply forwards the position index when this is a dense level.
1371  if (lt.isa<LevelFormat::Dense>()) {
1372  memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
1373  posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
1374  continue;
1375  }
1376  if (lt.isa<LevelFormat::Batch>()) {
1377  // Skips batch levels as they are not linearized.
1378  // FIXME: this assumes that every batch has the same number of nse; needs
1379  // to be generalized to handle varied-size batches.
1380  continue;
1381  }
1382 
1383  if (isWithPosLT(lt)) {
1384  assert(isCompressedLT(lt) || isLooseCompressedLT(lt));
1385  if (isLooseCompressedLT(lt)) {
1386  memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
1387  posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
1388  } else {
1389  assert(isCompressedLT(lt));
1390  posBack = memSize;
1391  memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
1392  }
1393  desc.setPosMemSize(rewriter, loc, lvl, memSize);
1394  // The last value in the position array is the memory size for the next level.
1395  // FIXME: this assumes that every batch has the same number of stored
1396  // entries (nse); this needs to be generalized to handle varied-size batches.
1397  SmallVector<Value> batched(stt.getBatchLvlRank(),
1398  constantIndex(rewriter, loc, 0));
1399  batched.push_back(posBack);
1400  memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
1401  posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
1402  }
1403  assert(isWithCrdLT(lt) && lvl <= trailCOOStart);
1404  // FIXME: This seems unnecessarily complex; can we simplify it?
1405  if (lvl == trailCOOStart) {
1406  Value cooSz = rewriter.create<arith::MulIOp>(
1407  loc, memSize, constantIndex(rewriter, loc, trailCOORank));
1408  desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
1409  } else {
1410  desc.setCrdMemSize(rewriter, loc, lvl, memSize);
1411  }
1412  }
1413  desc.setValMemSize(rewriter, loc, memSize);
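// Annotation (editorial worked example, not part of the original source),
// assuming a static d0 x d1 CSR input (levels: dense, compressed) with no
// trailing COO region (trailCOOStart == lvlRank):
//   lvl 0 (dense):      memSize = d0 * 1,  posBack = d0 - 1
//   lvl 1 (compressed): posMemSize(1) = d0 + 1,
//                       memSize = positions[1][d0]  (== nse),
//                       crdMemSize(1) = nse
//   after the loop:     valMemSize = nse
// When a trailing AoS COO region of rank trailCOORank starts at level
// trailCOOStart, the single shared coordinate buffer instead gets
// crdMemSize(trailCOOStart) = memSize * trailCOORank.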
1414 
1415  rewriter.replaceOpWithMultiple(op, {desc.getFields()});
1416  return success();
1417  }
1418 };
1419 
1420 struct SparseDisassembleOpConverter
1421  : public OpConversionPattern<DisassembleOp> {
1422  using OpConversionPattern::OpConversionPattern;
1423  SparseDisassembleOpConverter(const TypeConverter &typeConverter,
1424  MLIRContext *context)
1425  : OpConversionPattern(typeConverter, context) {}
1426 
1427  LogicalResult
1428  matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
1429  ConversionPatternRewriter &rewriter) const override {
1430  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1431  op.getTensor().getType());
1432  Location loc = op.getLoc();
1433  SmallVector<Value> retMem;
1434  SmallVector<Value> retLen;
1435  desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem,
1436  &retLen](FieldIndex fid,
1437  SparseTensorFieldKind fKind,
1438  Level lvl, LevelType lt) -> bool {
1439  if (fKind == SparseTensorFieldKind::StorageSpec)
1440  return true;
1441  SparseTensorType stt(desc.getRankedTensorType());
1442  Value sz, src;
1443  TypedValue<BaseMemRefType> dst;
1444  if (fKind == SparseTensorFieldKind::ValMemRef) {
1445  sz = desc.getValMemSize(rewriter, loc);
1446  src = desc.getValMemRef();
1447  dst = genToMemref(rewriter, loc, op.getOutValues());
1448 
1449  retMem.push_back(dst);
1450  Type valLenTp = op.getValLen().getType();
1451  retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
1452  } else {
1453  assert(fKind == SparseTensorFieldKind::PosMemRef ||
1454  fKind == SparseTensorFieldKind::CrdMemRef);
1455 
1456  sz = fKind == SparseTensorFieldKind::PosMemRef
1457  ? desc.getPosMemSize(rewriter, loc, lvl)
1458  : desc.getCrdMemSize(rewriter, loc, lvl);
1459  src = desc.getMemRefField(fid);
1460  dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
1461  retMem.push_back(dst);
1462  // Retrieves the corresponding level length type.
1463  Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
1464  retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
1465  }
1466  Value flatOut = dst;
1467  if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
1468  auto reassoc =
1469  getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());
1470  flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
1471  }
1472  Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
1473  Value srcMem = genSliceToSize(rewriter, loc, src, sz);
1474  rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
1475  return true;
1476  });
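// Annotation (editorial sketch, not part of the original source): for each
// non-specifier field, the callback above roughly generates
//   %dst    = <memref of the user-provided output tensor>   (genToMemref)
//   %srcSub = <subview of %src of size %sz>                 (genSliceToSize)
//   %dstSub = <subview of %dst of size %sz>
//   memref.copy %srcSub, %dstSub
// so only the first %sz elements actually stored in each buffer are copied
// out, and %sz itself is returned as the corresponding length value.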
1477 
1478  // Converts MemRefs back to Tensors.
1479  SmallVector<Value> retValues = llvm::to_vector(
1480  llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
1481  return rewriter.create<bufferization::ToTensorOp>(loc, v);
1482  }));
1483  // Appends the actual memory length used in each buffer returned.
1484  retValues.append(retLen.begin(), retLen.end());
1485  rewriter.replaceOp(op, retValues);
1486  return success();
1487  }
1488 };
1489 
1490 struct SparseNewConverter : public OpConversionPattern<NewOp> {
1491  using OpConversionPattern::OpConversionPattern;
1492  LogicalResult
1493  matchAndRewrite(NewOp op, OpAdaptor adaptor,
1494  ConversionPatternRewriter &rewriter) const override {
1495  Location loc = op.getLoc();
1496  const auto dstTp = getSparseTensorType(op.getResult());
1497  // Creating COO with NewOp is handled by direct IR codegen. All other cases
1498  // are handled by rewriting.
1499  if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
1500  return failure();
1501 
1502  // Implement as follows:
1503  // %reader = @createCheckedSparseTensorReader(%filename)
1504  // %nse = @getSparseTensorNSE(%reader)
1505  // %coo = bufferization.alloc_tensor an ordered COO with
1506  // dst dim ordering, size_hint = %nse
1507  // %coordinates = sparse_tensor.coordinates_buffer(%coo)
1508  // %values = sparse_tensor.values(%coo)
1509  // %isSorted = @sparseTensorReaderReadToBuffers(%coordinates, %values)
1510  // if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
1511  // update storage specifier
1512  // @delSparseTensorReader(%reader)
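// Annotation (editorial note, not part of the original source): for a 2-D
// ordered COO destination with index overhead types and f64 values, the
// allocated fields below are expected to be roughly
//   positions[0] : memref<?xindex>  (a single segment [0, nse])
//   coordinates  : memref<?xindex>  (one AoS buffer, 2 coordinates per entry)
//   values       : memref<?xf64>
//   specifier    : !sparse_tensor.storage_specifier<...>
// which the reader then fills through xs (AoS coordinates) and ys (values).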
1513  SmallVector<Value> dimSizesValues;
1514  Value dimSizesBuffer;
1515  Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
1516  dimSizesValues, dimSizesBuffer);
1517 
1518  // Get the number of stored entries.
1519  const Type indexTp = rewriter.getIndexType();
1520  Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
1521  {indexTp}, {reader}, EmitCInterface::Off)
1522  .getResult(0);
1523 
1524  // Construct the lvl sizes and the dim2lvl/lvl2dim buffers.
1525  SmallVector<Value> lvlSizesValues;
1526  Value dim2lvlBuffer;
1527  Value lvl2dimBuffer;
1528  genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
1529  lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
1530 
1531  // Construct allocation for each field.
1532  Value sizeHint = nse;
1533  SmallVector<Value> fields;
1534  createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint,
1535  lvlSizesValues, fields);
1536 
1537  // Read the COO tensor data.
1538  MutSparseTensorDescriptor desc(dstTp, fields);
1539  Value xs = desc.getAOSMemRef();
1540  Value ys = desc.getValMemRef();
1541  const Type boolTp = rewriter.getIntegerType(1);
1542  const Type elemTp = dstTp.getElementType();
1543  const Type crdTp = dstTp.getCrdType();
1544  SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
1545  overheadTypeFunctionSuffix(crdTp),
1546  primaryTypeFunctionSuffix(elemTp)};
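// Annotation (editorial note, not part of the original source): the runtime
// entry point name is composed from the coordinate and element type suffixes;
// for index coordinates and f64 values this presumably resolves to something
// like "getSparseTensorReaderReadToBuffers0F64".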
1547  Value isSorted =
1548  createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
1549  {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
1550  EmitCInterface::On)
1551  .getResult(0);
1552 
1553  // If the destination tensor is a sorted COO, we need to sort the COO tensor
1554  // data if the input elements aren't sorted yet.
1555  const Level lvlRank = dstTp.getLvlRank();
1556  if (dstTp.isOrderedLvl(lvlRank - 1)) {
1557  Value kFalse = constantI1(rewriter, loc, false);
1558  Value notSorted = rewriter.create<arith::CmpIOp>(
1559  loc, arith::CmpIPredicate::eq, isSorted, kFalse);
1560  scf::IfOp ifOp =
1561  rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
1562  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1563  auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
1564  rewriter.create<SortOp>(loc, nse, xs, ValueRange{ys}, xPerm,
1565  rewriter.getIndexAttr(0),
1566  SparseTensorSortKind::HybridQuickSort);
1567  rewriter.setInsertionPointAfter(ifOp);
1568  }
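// Annotation (editorial sketch, not part of the original source): the guard
// above roughly materializes as
//   %notSorted = arith.cmpi eq, %isSorted, %false : i1
//   scf.if %notSorted {
//     sparse_tensor.sort hybrid_quick_sort %nse, %xs jointly %ys ...
//   }
// i.e. the AoS coordinate buffer (and, jointly, the values buffer) is sorted
// using the identity permutation over all lvlRank coordinates.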
1569 
1570  // Set PosMemRef0[1] = nse.
1571  const Value c1 = constantIndex(rewriter, loc, 1);
1572  const Value posMemref0 = desc.getPosMemRef(0);
1573  const Type posTp = dstTp.getPosType();
1574  const Value posNse = genCast(rewriter, loc, nse, posTp);
1575  rewriter.create<memref::StoreOp>(loc, posNse, posMemref0, c1);
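// Annotation (editorial note, not part of the original source): together with
// the leading zero stored into this position buffer when the fields were
// allocated, the top-level position buffer now holds [0, nse], i.e. a single
// segment covering all stored entries of the COO.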
1576 
1577  // Update storage specifier.
1578  Value coordinatesSize = rewriter.create<arith::MulIOp>(
1579  loc, nse, constantIndex(rewriter, loc, lvlRank));
1580  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0,
1581  coordinatesSize);
1582  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize,
1583  std::nullopt, nse);
1584 
1585  // Release the sparse tensor reader.
1586  createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
1587  EmitCInterface::Off);
1588 
1589  // Replace operation with resulting memrefs.
1590  rewriter.replaceOpWithMultiple(op, {fields});
1591  return success();
1592  }
1593 };
1594 
1595 struct SparseHasRuntimeLibraryConverter
1596  : public OpConversionPattern<HasRuntimeLibraryOp> {
1597  using OpConversionPattern::OpConversionPattern;
1598  LogicalResult
1599  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
1600  ConversionPatternRewriter &rewriter) const override {
1601  auto i1Type = rewriter.getI1Type();
1602  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
1603  op, i1Type, rewriter.getIntegerAttr(i1Type, 0));
1604  return success();
1605  }
1606 };
1607 
1608 } // namespace
1609 
1610 //===----------------------------------------------------------------------===//
1611 // Public method for populating conversion rules.
1612 //===----------------------------------------------------------------------===//
1613 
1614 /// Populates the given patterns list with conversion rules required for
1615 /// the sparsification of linear algebra operations.
1616  void mlir::populateSparseTensorCodegenPatterns(
1617  const TypeConverter &typeConverter, RewritePatternSet &patterns,
1618  bool createSparseDeallocs, bool enableBufferInitialization) {
1619  patterns.add<
1620  SparseAssembleOpConverter, SparseDisassembleOpConverter,
1621  SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
1622  SparseCastConverter, SparseExtractSliceConverter,
1623  SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter,
1624  SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter,
1625  SparseSliceGetterOpConverter<ToSliceOffsetOp,
1626  StorageSpecifierKind::DimOffset>,
1627  SparseSliceGetterOpConverter<ToSliceStrideOp,
1628  StorageSpecifierKind::DimStride>,
1629  SparseToPositionsConverter, SparseToCoordinatesConverter,
1630  SparseToCoordinatesBufferConverter, SparseToValuesConverter,
1631  SparseConvertConverter, SparseNewConverter,
1632  SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
1633  typeConverter, patterns.getContext());
1634  patterns.add<SparseTensorDeallocConverter>(
1635  typeConverter, patterns.getContext(), createSparseDeallocs);
1636  patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
1637  typeConverter, patterns.getContext(), enableBufferInitialization);
1638 }
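
// Annotation (editorial usage sketch, not part of the original source): a
// conversion pass would typically wire these patterns up roughly as follows,
// assuming the usual type converter and conversion target setup (the names
// `converter`, `target`, and the option values are illustrative):
//
//   SparseTensorTypeToBufferConverter converter;
//   ConversionTarget target(*ctx);
//   RewritePatternSet patterns(ctx);
//   populateSparseTensorCodegenPatterns(converter, patterns,
//                                       /*createSparseDeallocs=*/false,
//                                       /*enableBufferInitialization=*/false);
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();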