SparseTensorCodegen.cpp
1 //===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // A pass that converts sparse tensor types and primitives to compiler-visible
10 // buffers and the compiler IR that implements these primitives on
11 // the selected sparse tensor storage schemes. This pass provides an alternative
12 // to the SparseTensorConversion pass, eliminating the dependence on a runtime
13 // support library (other than for file I/O), and providing many more
14 // opportunities for subsequent compiler optimization of the generated code.
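//
// As an illustration (the normative field list is defined by the sparse tensor
// storage layout utilities), a 2-D CSR-like tensor is lowered to a flat
// collection of buffers roughly along the lines of:
//   positions[1]   : memref<?xindex>   (row "pointers")
//   coordinates[1] : memref<?xindex>   (column coordinates)
//   values         : memref<?xf64>     (stored values)
//   specifier      : sparse_tensor.storage_specifier (per-level size metadata)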
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "Utils/CodegenUtils.h"
19 #include "Utils/SparseTensorDescriptor.h"
20 
32 
33 #include <optional>
34 
35 using namespace mlir;
36 using namespace mlir::sparse_tensor;
37 
38 //===----------------------------------------------------------------------===//
39 // Helper methods.
40 //===----------------------------------------------------------------------===//
41 
42 /// Flatten the given value ranges into a single vector of values.
43 static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
44  SmallVector<Value> result;
45  for (const auto &vals : values)
46  llvm::append_range(result, vals);
47  return result;
48 }
49 
50 /// Generates a load with proper `index` typing.
51 static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
52  idx = genCast(builder, loc, idx, builder.getIndexType());
53  return builder.create<memref::LoadOp>(loc, mem, idx);
54 }
55 
56 /// Generates a store with proper `index` typing and proper value.
57 static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
58  Value idx) {
59  idx = genCast(builder, loc, idx, builder.getIndexType());
60  val = genCast(builder, loc, val,
61  cast<ShapedType>(mem.getType()).getElementType());
62  builder.create<memref::StoreOp>(loc, val, mem, idx);
63 }
64 
65 /// Creates a straightforward counting for-loop.
66 static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
67  MutableArrayRef<Value> fields,
68  Value lower = Value()) {
69  Type indexType = builder.getIndexType();
70  if (!lower)
71  lower = constantZero(builder, loc, indexType);
72  Value one = constantOne(builder, loc, indexType);
73  scf::ForOp forOp = builder.create<scf::ForOp>(loc, lower, upper, one, fields);
74  for (unsigned i = 0, e = fields.size(); i < e; i++)
75  fields[i] = forOp.getRegionIterArg(i);
76  builder.setInsertionPointToStart(forOp.getBody());
77  return forOp;
78 }
79 
80 /// Creates a push back operation.
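/// Conceptually, this appends `value` (repeated `repeat` times when provided)
/// to the memref field selected by `kind`/`lvl`, threads the possibly
/// reallocated buffer back into the descriptor, and records the new size in
/// the corresponding storage specifier field.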
81 static void createPushback(OpBuilder &builder, Location loc,
82  MutSparseTensorDescriptor desc,
83  SparseTensorFieldKind kind, std::optional<Level> lvl,
84  Value value, Value repeat = Value()) {
85  Type etp = desc.getMemRefElementType(kind, lvl);
86  Value field = desc.getMemRefField(kind, lvl);
87  StorageSpecifierKind specFieldKind = toSpecifierKind(kind);
88 
89  auto pushBackOp = builder.create<PushBackOp>(
90  loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field,
91  genCast(builder, loc, value, etp), repeat);
92 
93  desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
94  desc.setSpecifierField(builder, loc, specFieldKind, lvl,
95  pushBackOp.getNewSize());
96 }
97 
98 /// Generates code that allocates a sparse storage scheme for the given rank.
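/// For example, for a CSR-like (dense, compressed) scheme starting at level 0,
/// the dense level multiplies `linear` by its size, and the compressed level
/// then appends that many zero positions; together with the single zero entry
/// pushed by createAllocFields this maintains the "linear + 1" property.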
99 static void allocSchemeForRank(OpBuilder &builder, Location loc,
100  MutSparseTensorDescriptor desc, Level startLvl) {
101  const SparseTensorType stt(desc.getRankedTensorType());
102  Value linear = constantIndex(builder, loc, 1);
103  const Level lvlRank = stt.getLvlRank();
104  for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
105  const auto lt = stt.getLvlType(lvl);
106  if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
107  // Append linear x positions, initialized to zero. Since each compressed
108  // dimension initially already has a single zero entry, this maintains
109  // the desired "linear + 1" length property at all times. For loose
110  // compression, we multiply linear by two in order to append both the
111  // lo/hi positions.
112  Value posZero = constantZero(builder, loc, stt.getPosType());
113  if (isLooseCompressedLT(lt)) {
114  Value two = constantIndex(builder, loc, 2);
115  linear = builder.create<arith::MulIOp>(loc, linear, two);
116  }
117  createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
118  /*value=*/posZero, /*repeat=*/linear);
119  return;
120  } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
121  return; // nothing to do
122  }
123  // Keep compounding the size, but nothing needs to be initialized
124  // at this level. We will eventually reach a compressed level or
125  // otherwise the values array for the from-here "all-dense" case.
126  assert(isDenseLT(lt));
127  Value size = desc.getLvlSize(builder, loc, lvl);
128  linear = builder.create<arith::MulIOp>(loc, linear, size);
129  }
130  // Reached values array so prepare for an insertion.
131  Value valZero = constantZero(builder, loc, stt.getElementType());
132  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
133  std::nullopt, /*value=*/valZero, /*repeat=*/linear);
134 }
135 
136 /// Creates an allocation operation.
137 static Value createAllocation(OpBuilder &builder, Location loc,
138  MemRefType memRefType, Value sz,
139  bool enableInit) {
140  Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
141  Type elemType = memRefType.getElementType();
142  if (enableInit) {
143  Value fillValue = constantZero(builder, loc, elemType);
144  builder.create<linalg::FillOp>(loc, fillValue, buffer);
145  }
146  return buffer;
147 }
148 
149 /// Creates the dim sizes array, filling in from dynamic sizes.
150 static void createDimSizes(OpBuilder &builder, Location loc,
151  SparseTensorType stt, ValueRange dynSizes,
152  /*out*/ SmallVectorImpl<Value> &dimSizesValues) {
153  const Dimension dimRank = stt.getDimRank();
154  dimSizesValues.clear();
155  dimSizesValues.reserve(dimRank);
156  unsigned i = 0;
157  for (const Size sz : stt.getDimShape())
158  dimSizesValues.push_back(ShapedType::isDynamic(sz)
159  ? dynSizes[i++]
160  : constantIndex(builder, loc, sz));
161 }
162 
163 /// Creates allocations for each field in the sparse tensor type. Note that
164 /// for all dynamic memrefs in the sparse tensor storage layout, the
165 /// memory size is really the capacity of the "vector", while the actual
166 /// size resides in the sizes array.
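/// For example, without a size hint the values memref below is allocated with
/// a small heuristic capacity (16 elements), while the sizes recorded in the
/// storage specifier only grow as push_back operations append actual data.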
167 static void createAllocFields(OpBuilder &builder, Location loc,
168  SparseTensorType stt, bool enableInit,
169  Value sizeHint,
170  SmallVectorImpl<Value> &lvlSizesValues,
171  /*out*/ SmallVectorImpl<Value> &fields) {
172  Level lvlRank = stt.getLvlRank();
173  // Set up some heuristic sizes. We try to set the initial
174  // size based on available information. Otherwise we just
175  // initialize a few elements to start the reallocation chain.
176  // TODO: refine this
177  Value posHeuristic, crdHeuristic, valHeuristic;
178  if (stt.isAllDense()) {
179  valHeuristic = lvlSizesValues[0];
180  for (Level lvl = 1; lvl < lvlRank; lvl++)
181  valHeuristic =
182  builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
183  } else if (sizeHint) {
184  if (stt.getAoSCOOStart() == 0) {
185  posHeuristic = constantIndex(builder, loc, 2);
186  crdHeuristic = builder.create<arith::MulIOp>(
187  loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
188  } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
189  posHeuristic = builder.create<arith::AddIOp>(
190  loc, sizeHint, constantIndex(builder, loc, 1));
191  crdHeuristic = sizeHint;
192  } else {
193  posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
194  }
195  valHeuristic = sizeHint;
196  } else {
197  posHeuristic = crdHeuristic = valHeuristic =
198  constantIndex(builder, loc, 16);
199  }
200  // Initializes all fields. An initial storage specifier and allocated
201  // positions/coordinates/values memrefs (with heuristic capacity).
202  foreachFieldAndTypeInSparseTensor(
203  stt,
204  [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
205  enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
206  Level /*lvl*/, LevelType /*lt*/) -> bool {
207  assert(fields.size() == fIdx);
208  Value field;
209  switch (fKind) {
210  case SparseTensorFieldKind::StorageSpec:
211  field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
212  break;
213  case SparseTensorFieldKind::PosMemRef:
214  field = createAllocation(builder, loc, cast<MemRefType>(fType),
215  posHeuristic, enableInit);
216  break;
217  case SparseTensorFieldKind::CrdMemRef:
218  field = createAllocation(builder, loc, cast<MemRefType>(fType),
219  crdHeuristic, enableInit);
220  break;
221  case SparseTensorFieldKind::ValMemRef:
222  field = createAllocation(builder, loc, cast<MemRefType>(fType),
223  valHeuristic, enableInit);
224  break;
225  }
226  assert(field);
227  fields.push_back(field);
228  // Returns true to continue the iteration.
229  return true;
230  });
231  // Initialize the storage scheme to an empty tensor. Sets the lvlSizes
232  // and gives all position fields an initial zero entry, so that it is
233  // easier to maintain the "linear + 1" length property.
234  MutSparseTensorDescriptor desc(stt, fields);
235  Value posZero = constantZero(builder, loc, stt.getPosType());
236  for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
237  desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
238  const auto lt = stt.getLvlType(lvl);
239  if (isCompressedLT(lt) || isLooseCompressedLT(lt))
240  createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
241  /*value=*/posZero);
242  }
243  allocSchemeForRank(builder, loc, desc, /*rank=*/0);
244 }
245 
246 /// Helper method that generates block specific to compressed case:
247 ///
248 /// // given: parentPos = posCursor[lvl-1]
249 /// pstart = desc.positions[lvl][parentPos]
250 /// pstop = desc.positions[lvl][parentPos+1]
251 /// plast = pstop - 1
252 /// msz = desc.coordinates[lvl].size()
253 /// if (pstart < pstop) {
254 /// isPresent = (desc.coordinates[lvl][plast] == lvlCoords[lvl])
255 /// } else { // first insertion
256 /// isPresent = false
257 /// desc.positions[lvl][parentPos] = msz
258 /// }
259 /// if (isPresent) { // coordinate is already present
260 /// pnext = plast
261 /// } else {
262 /// desc.coordinates[lvl].push_back(lvlCoords[lvl])
263 /// desc.positions[lvl][parentPos+1] = msz+1
264 /// pnext = msz
265 /// <prepare level lvl+1>
266 /// }
267 /// posCursor[lvl] = pnext
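/// Here `parentPos` is the position inherited from level `lvl-1` (already
/// scaled by two for loose compression, see the caller in genImplementation).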
268 static Value genCompressed(OpBuilder &builder, Location loc,
269  MutSparseTensorDescriptor desc, ValueRange lvlCoords,
270  Value /*unused*/, Value parentPos, Level lvl) {
271  const SparseTensorType stt(desc.getRankedTensorType());
272  const Level lvlRank = stt.getLvlRank();
273  assert(lvl < lvlRank && "Level is out of bounds");
274  assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
275  "Level-rank mismatch");
276  SmallVector<Type> types;
277  Type indexType = builder.getIndexType();
278  Type boolType = builder.getIntegerType(1);
279  unsigned crdFidx;
280  unsigned crdStride;
281  std::tie(crdFidx, crdStride) = desc.getCrdMemRefIndexAndStride(lvl);
282  const Value one = constantIndex(builder, loc, 1);
283  const Value pp1 = builder.create<arith::AddIOp>(loc, parentPos, one);
284  const Value positionsAtLvl = desc.getPosMemRef(lvl);
285  const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
286  const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
287  const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl);
288  const Value crdStrideC =
289  crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
290  const Value msz =
291  crdStrideC ? builder.create<arith::DivUIOp>(loc, crdMsz, crdStrideC)
292  : crdMsz;
293  const Value plast = builder.create<arith::SubIOp>(
294  loc, genCast(builder, loc, pstop, indexType), one);
295  // Conditional expression.
296  Value lt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
297  pstart, pstop);
298  types.push_back(boolType);
299  scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
300  types.pop_back();
301  builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
302  Value crd =
303  genLoad(builder, loc, desc.getMemRefField(crdFidx),
304  crdStrideC ? builder.create<arith::MulIOp>(loc, plast, crdStrideC)
305  : plast);
306  Value eq = builder.create<arith::CmpIOp>(
307  loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
308  lvlCoords[lvl]);
309  builder.create<scf::YieldOp>(loc, eq);
310  builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
311  if (lvl > 0)
312  genStore(builder, loc, msz, positionsAtLvl, parentPos);
313  builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
314  builder.setInsertionPointAfter(ifOp1);
315  // If present construct. Note that for a non-unique dimension level, we
316  // simply set the condition to false and rely on CSE/DCE to clean up the IR.
317  //
318  // TODO: generate less temporary IR?
319  //
320  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
321  types.push_back(desc.getField(i).getType());
322  types.push_back(indexType);
323  const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
324  : constantI1(builder, loc, false);
325  scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true);
326  // If present (fields unaffected, update pnext to plast).
327  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
328 
329  // FIXME: This does not look like a clean way, but it is probably the most
330  // efficient way.
331  desc.getFields().push_back(plast);
332  builder.create<scf::YieldOp>(loc, desc.getFields());
333  desc.getFields().pop_back();
334 
335  // If !present (changes fields, update pnext).
336  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
337  Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
338  genStore(builder, loc, mszp1, positionsAtLvl, pp1);
339  createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
340  /*value=*/lvlCoords[lvl]);
341  // Prepare the next level "as needed".
342  if ((lvl + 1) < lvlRank)
343  allocSchemeForRank(builder, loc, desc, lvl + 1);
344 
345  desc.getFields().push_back(msz);
346  builder.create<scf::YieldOp>(loc, desc.getFields());
347  desc.getFields().pop_back();
348 
349  // Update fields and return next pos.
350  builder.setInsertionPointAfter(ifOp2);
351  unsigned o = 0;
352  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
353  desc.setField(i, ifOp2.getResult(o++));
354  return ifOp2.getResult(o);
355 }
356 
357 /// Generates insertion finalization code.
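/// For a compressed level, this back-fills position entries that were never
/// visited during insertion, e.g. a positions array [0, 2, 0, 0, 5] becomes
/// [0, 2, 2, 2, 5] (illustrative example).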
358 static void genEndInsert(OpBuilder &builder, Location loc,
359  SparseTensorDescriptor desc) {
360  const SparseTensorType stt(desc.getRankedTensorType());
361  const Level lvlRank = stt.getLvlRank();
362  for (Level lvl = 0; lvl < lvlRank; lvl++) {
363  const auto lt = stt.getLvlType(lvl);
364  if (isCompressedLT(lt)) {
365  // Compressed dimensions need a position cleanup for all entries
366  // that were not visited during the insertion pass.
367  //
368  // TODO: avoid cleanup and keep compressed scheme consistent at all
369  // times?
370  //
371  if (lvl > 0) {
372  Type posType = stt.getPosType();
373  Value posMemRef = desc.getPosMemRef(lvl);
374  Value hi = desc.getPosMemSize(builder, loc, lvl);
375  Value zero = constantIndex(builder, loc, 0);
376  Value one = constantIndex(builder, loc, 1);
377  // Vector of only one element, but needed by createFor's prototype.
378  SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
379  scf::ForOp loop = createFor(builder, loc, hi, inits, one);
380  Value i = loop.getInductionVar();
381  Value oldv = loop.getRegionIterArg(0);
382  Value newv = genLoad(builder, loc, posMemRef, i);
383  Value posZero = constantZero(builder, loc, posType);
384  Value cond = builder.create<arith::CmpIOp>(
385  loc, arith::CmpIPredicate::eq, newv, posZero);
386  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType),
387  cond, /*else*/ true);
388  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
389  genStore(builder, loc, oldv, posMemRef, i);
390  builder.create<scf::YieldOp>(loc, oldv);
391  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
392  builder.create<scf::YieldOp>(loc, newv);
393  builder.setInsertionPointAfter(ifOp);
394  builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
395  builder.setInsertionPointAfter(loop);
396  }
397  } else {
398  assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
399  isNOutOfMLT(lt));
400  }
401  }
402 }
403 
404 /// Generates a subview of the memref, restricted to the given size.
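/// For a 1-D memref this emits roughly
///   memref.subview %mem[0] [%sz] [1] : memref<?xT> to memref<?xT>
/// (illustrative syntax).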
405 static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
406  Value sz) {
407  auto memTp = llvm::cast<MemRefType>(mem.getType());
408  // For higher-dimensional memrefs, we assume that the innermost
409  // dimension is always of the right size.
410  // TODO: generate complex truncating view here too?
411  if (memTp.getRank() > 1)
412  return mem;
413  // Truncate linear memrefs to given size.
414  return builder
415  .create<memref::SubViewOp>(
416  loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
417  mem, ValueRange{}, ValueRange{sz}, ValueRange{},
418  ArrayRef<int64_t>{0}, // static offset
419  ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
420  ArrayRef<int64_t>{1}) // static stride
421  .getResult();
422 }
423 
424 /// Creates the reassociation array.
425 static SmallVector<ReassociationIndices>
426 getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
427  SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
428  // Create reassociation in the form:
429  // {0}, {1}, ..., {batchLvl - 1}, {batchLvl, ..., rank}
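  // e.g., with batchLvls == 2 and a rank-4 source: {0}, {1}, {2, 3}.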
430  for (unsigned i = 0; i < batchLvls; i++)
431  ret[i].push_back(i);
432 
433  for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
434  ret.back().push_back(i);
435 
436  return ret;
437 }
438 
439 //===----------------------------------------------------------------------===//
440 // Codegen rules.
441 //===----------------------------------------------------------------------===//
442 
443 namespace {
444 
445 /// Helper class to help lowering sparse_tensor.insert operation.
446 class SparseInsertGenerator
447  : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
448 public:
449  SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
450  bool genCall)
451  : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp){};
452 
453  /// Generates code along an insertion path without the need for a "cursor".
454 /// The current insertion strategy comes at the expense of some testing
455  /// overhead for each insertion. The strategy will be optimized later for
456  /// common insertion patterns. The current insertion strategy also assumes
457  /// insertions occur in "a reasonable order" that enables building the
458  /// storage scheme in an appending/inserting kind of fashion (i.e. no
459  /// in-between insertions that need data movement). The implementation
460  /// relies on CSE/DCE to clean up all bookkeeping that is not needed.
461  ///
462  /// TODO: better unord/not-unique; also generalize, optimize, specialize!
463  SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
464  OpBuilder &builder, Location loc) {
465  const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
466  const Level lvlRank = stt.getLvlRank();
467  // Extract fields and coordinates from args.
468  SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
469  MutSparseTensorDescriptor desc(stt, fields);
470  const SmallVector<Value> coords =
471  llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
472  Value value = args.back();
473  Value parentPos = constantZero(builder, loc, builder.getIndexType());
474  // Generate code for every level.
475  for (Level lvl = 0; lvl < lvlRank; lvl++) {
476  const auto lt = stt.getLvlType(lvl);
477  if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
478  // Create:
479  // if (!present) {
480  // coordinates[lvl].push_back(coords[lvl])
481  // <update positions and prepare level lvl + 1>
482  // }
483  // positions[lvl] = coordinates.size() - 1
484  // <insert @ positions[lvl] at next level lvl + 1>
485  if (isLooseCompressedLT(lt)) {
486  Value two = constantIndex(builder, loc, 2);
487  parentPos = builder.create<arith::MulIOp>(loc, parentPos, two);
488  }
489  parentPos =
490  genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
491  } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
492  // Create:
493  // coordinates[lvl].push_back(coords[lvl])
494  // positions[lvl] = positions[lvl-1]
495  // <insert @ positions[lvl] at next level lvl + 1>
496  createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
497  lvl, /*value=*/coords[lvl]);
498  } else {
499  assert(isDenseLT(lt));
500  // Construct the new position as:
501  // positions[lvl] = size * positions[lvl-1] + coords[lvl]
502  // <insert @ positions[lvl] at next level lvl + 1>
503  Value size = desc.getLvlSize(builder, loc, lvl);
504  Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
505  parentPos = builder.create<arith::AddIOp>(loc, mult, coords[lvl]);
506  }
507  }
508  // Reached the actual value append/insert.
509  if (!stt.isDenseLvl(lvlRank - 1))
510  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
511  std::nullopt, value);
512  else
513  genStore(builder, loc, value, desc.getValMemRef(), parentPos);
514  return fields;
515  }
516 
517  std::string getMangledFuncName() {
518  // The mangled name of the function has this format:
519  // <namePrefix>_<LT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
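  // For example, a 2-D (dense, compressed) tensor of shape 8x8 with f64
  // elements, an identity ordering, and default index bit widths would yield
  // a name along the lines of "_insert_dense_compressed_8_8_f64_0_0"
  // (illustrative only; the exact string depends on the printed level types).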
520  constexpr const char kInsertFuncNamePrefix[] = "_insert_";
521  const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
522  SmallString<32> nameBuffer;
523  llvm::raw_svector_ostream nameOstream(nameBuffer);
524  nameOstream << kInsertFuncNamePrefix;
525  const Level lvlRank = stt.getLvlRank();
526  for (Level l = 0; l < lvlRank; l++) {
527  std::string lvlType = toMLIRString(stt.getLvlType(l));
528  // Replace/remove punctuations in level properties.
529  std::replace_if(
530  lvlType.begin(), lvlType.end(),
531  [](char c) { return c == '(' || c == ','; }, '_');
532  llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
533  nameOstream << lvlType << "_";
534  }
535  // Static dim sizes are used in the generated code while dynamic sizes are
536  // loaded from the dimSizes buffer. This is the reason for adding the shape
537  // to the function name.
538  for (const auto sz : stt.getDimShape())
539  nameOstream << sz << "_";
540  // Permutation information is also used in generating insertion.
541  if (!stt.isIdentity())
542  nameOstream << stt.getDimToLvl() << "_";
543  nameOstream << stt.getElementType() << "_";
544  nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
545  return nameOstream.str().str();
546  }
547 
548 private:
549  TensorType rtp;
550 };
551 
552 /// Sparse tensor storage conversion rule for returns.
553 class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
554 public:
555  using OpConversionPattern::OpConversionPattern;
556  LogicalResult
557  matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
558  ConversionPatternRewriter &rewriter) const override {
559  // Create a return with the flattened value extracted from sparse tensors.
560  rewriter.replaceOpWithNewOp<func::ReturnOp>(
561  op, flattenValues(adaptor.getOperands()));
562  return success();
563  }
564 };
565 
566 /// Sparse tensor storage conversion rule for calls.
567 class SparseCallConverter : public OpConversionPattern<func::CallOp> {
568 public:
569  // The default CallOp converter cannot handle 1:N type conversion.
570  using OpConversionPattern::OpConversionPattern;
571  LogicalResult
572  matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
573  ConversionPatternRewriter &rewriter) const override {
574  Location loc = op.getLoc();
575  // In case of:
576  // sparse_tensor, f, sparse_tensor = call @foo(...)
577  // ==>
578  // memref..., f, memref = call @foo(...) replace with
579  // cast(memref...)->sparse_tensor, f, cast(memref...)->sparse_tensor
580  SmallVector<Type> finalRetTy;
581  if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
582  return failure();
583 
584  // (1) Generates new call with flattened return value.
585  auto newCall = rewriter.create<func::CallOp>(
586  loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands()));
587  // (2) Gather sparse tensor returns.
588  SmallVector<SmallVector<Value>> packedResultVals;
589  // Tracks the offset of the current return value (of the original call)
590  // relative to the new call (after sparse tensor flattening).
591  unsigned retOffset = 0;
592  // Temporary buffer to hold the flattened list of types for
593  // a sparse tensor.
594  SmallVector<Type> sparseFlat;
595  for (auto ret : op.getResults()) {
596  assert(retOffset < newCall.getNumResults());
597  auto retType = ret.getType();
598  if (failed(typeConverter->convertType(retType, sparseFlat)))
599  llvm_unreachable("Failed to convert type in sparse tensor codegen");
600 
601  // Converted types cannot be empty when the type conversion succeeds.
602  assert(!sparseFlat.empty());
603  if (sparseFlat.size() > 1) {
604  auto flatSize = sparseFlat.size();
605  packedResultVals.emplace_back();
606  llvm::append_range(packedResultVals.back(),
607  newCall.getResults().slice(retOffset, flatSize));
608  retOffset += flatSize;
609  } else {
610  // If this is a 1:1 conversion, no need for casting.
611  packedResultVals.emplace_back();
612  packedResultVals.back().push_back(newCall.getResult(retOffset));
613  retOffset++;
614  }
615  sparseFlat.clear();
616  }
617 
618  assert(packedResultVals.size() == op.getNumResults());
619  rewriter.replaceOpWithMultiple(op, std::move(packedResultVals));
620  return success();
621  }
622 };
623 
624 /// Sparse codegen rule for level accesses.
625 class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
626 public:
627  using OpConversionPattern::OpConversionPattern;
628  LogicalResult
629  matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
630  ConversionPatternRewriter &rewriter) const override {
631  std::optional<int64_t> lvl = op.getConstantLvlIndex();
632  RankedTensorType srcType = op.getSource().getType();
633  if (!lvl || !getSparseTensorEncoding(srcType))
634  return failure();
635 
636  auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
637  auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
638 
639  rewriter.replaceOp(op, sz);
640  return success();
641  }
642 };
643 
644 // TODO: use a new SortCOO operation here instead of reusing convert op.
645 struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
646  using OpConversionPattern::OpConversionPattern;
647  LogicalResult
648  matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
649  ConversionPatternRewriter &rewriter) const override {
650  Location loc = op.getLoc();
651  MLIRContext *ctx = op.getContext();
652 
653  SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
654  SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());
655 
656  // Should have been verified.
657  assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
658  dstStt.isCOOType() && srcStt.isCOOType());
659  assert(dstStt.hasSameDimToLvl(srcStt));
660 
661  // We don't need a mutable descriptor here as we perform sorting in-place.
662  auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
663  op.getInputCoo().getType());
664  auto nnz = desc.getValMemSize(rewriter, op.getLoc());
665  auto crd = desc.getAOSMemRef();
666  auto val = desc.getValMemRef();
667 
668  // Otherwise we need another data shuffle and a non-identity map.
669  assert(dstStt.hasSameDimToLvl(srcStt));
670  (void)dstStt; // to silence warning when assertion is disabled
671 
672  auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
673 
674  rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
675  rewriter.getIndexAttr(0), op.getAlgorithm());
676 
677  // Since we do in-place sorting, the destination tensor will have the same set
678  // of memrefs as the source tensor.
679  rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()});
680  return success();
681  }
682 };
683 
684 template <typename Op, StorageSpecifierKind kind>
685 class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
686 public:
687  using OpConversionPattern<Op>::OpConversionPattern;
688  using typename OpConversionPattern<Op>::OneToNOpAdaptor;
689 
690  LogicalResult
691  matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
692  ConversionPatternRewriter &rewriter) const override {
693  // Simply lowers to a specifier.get <field> operation.
694  auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
695  op.getSlice().getType());
696  auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
697  op.getDim().getZExtValue());
698 
699  rewriter.replaceOp(op, v);
700  return success();
701  }
702 };
703 
704 /// Sparse codegen rule for trivial tensor casts.
705 class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
706 public:
707  using OpConversionPattern::OpConversionPattern;
708  LogicalResult
709  matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
710  ConversionPatternRewriter &rewriter) const override {
711  // Only rewrite identically annotated source/dest.
712  auto encDst = getSparseTensorEncoding(op.getType());
713  auto encSrc = getSparseTensorEncoding(op.getSource().getType());
714  if (!encDst || encDst != encSrc)
715  return failure();
716  rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
717  return success();
718  }
719 };
720 
721 class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
722 public:
723  using OpConversionPattern::OpConversionPattern;
724  LogicalResult
725  matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
726  ConversionPatternRewriter &rewriter) const override {
727  // Simply fold the operation.
728  rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
729  return success();
730  }
731 };
732 
733 /// Sparse codegen rule for the alloc operator.
734 class SparseTensorAllocConverter
735  : public OpConversionPattern<bufferization::AllocTensorOp> {
736 public:
737  using OpConversionPattern::OpConversionPattern;
738  SparseTensorAllocConverter(const TypeConverter &typeConverter,
739  MLIRContext *context, bool enableInit)
740  : OpConversionPattern(typeConverter, context),
741  enableBufferInitialization(enableInit) {}
742 
743  LogicalResult
744  matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
745  ConversionPatternRewriter &rewriter) const override {
746  const auto resType = getSparseTensorType(op);
747  if (!resType.hasEncoding())
748  return failure();
749 
750  Location loc = op.getLoc();
751  // Deal with copy.
752  if (op.getCopy()) {
753  auto desc = getDescriptorFromTensorTuple(
754  adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
755  SmallVector<Value> fields;
756  fields.reserve(desc.getNumFields());
757  // Memcpy on memref fields.
758  for (auto field : desc.getMemRefFields()) {
759  auto memrefTp = cast<MemRefType>(field.getType());
760  auto size = rewriter.create<memref::DimOp>(loc, field, 0);
761  auto copied =
762  rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size});
763  rewriter.create<memref::CopyOp>(loc, field, copied);
764  fields.push_back(copied);
765  }
766  // Reuses specifier.
767  fields.push_back(desc.getSpecifier());
768  assert(fields.size() == desc.getNumFields());
769  rewriter.replaceOpWithMultiple(op, {fields});
770  return success();
771  }
772 
773  if (!resType.isIdentity()) {
774  return rewriter.notifyMatchFailure(
775  op, "try run --sparse-reinterpret-map before codegen");
776  }
777  // Level size equals dimension size since the lvl2dim map is an identity map.
778  SmallVector<Value> lvlSizesValues;
779  createDimSizes(rewriter, loc, resType,
780  flattenValues(adaptor.getDynamicSizes()),
781  /*dimSizesValues=*/lvlSizesValues);
782 
783  // Construct allocation for each field.
784  Value sizeHint = op.getSizeHint();
785  SmallVector<Value> fields;
786  createAllocFields(rewriter, loc, resType, enableBufferInitialization,
787  sizeHint, lvlSizesValues, fields);
788 
789  // Replace operation with resulting memrefs.
790  rewriter.replaceOpWithMultiple(op, {fields});
791  return success();
792  }
793 
794 private:
795  bool enableBufferInitialization;
796 };
797 
798 /// Sparse codegen rule for the empty tensor operator.
799 class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
800 public:
801  using OpConversionPattern::OpConversionPattern;
802  SparseTensorEmptyConverter(const TypeConverter &typeConverter,
803  MLIRContext *context, bool enableInit)
804  : OpConversionPattern(typeConverter, context),
805  enableBufferInitialization(enableInit) {}
806 
807  LogicalResult
808  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
809  ConversionPatternRewriter &rewriter) const override {
810  const auto resType = getSparseTensorType(op);
811  if (!resType.hasEncoding())
812  return failure();
813 
814  if (!resType.isIdentity()) {
815  return rewriter.notifyMatchFailure(
816  op, "try run --sparse-reinterpret-map before codegen");
817  }
818 
819  Location loc = op.getLoc();
820  // Level size equals dimension size since the lvl2dim map is an identity map.
821  SmallVector<Value> lvlSizesValues;
822  createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
823  /*dimSizesValues=*/lvlSizesValues);
824  // Construct allocation for each field.
825  Value sizeHint; // none
826  SmallVector<Value> fields;
827  createAllocFields(rewriter, loc, resType, enableBufferInitialization,
828  sizeHint, lvlSizesValues, fields);
829 
830  // Replace operation with resulting memrefs.
831  rewriter.replaceOpWithMultiple(op, {fields});
832  return success();
833  }
834 
835 private:
836  bool enableBufferInitialization;
837 };
838 
839 /// Sparse codegen rule for the dealloc operator.
840 class SparseTensorDeallocConverter
841  : public OpConversionPattern<bufferization::DeallocTensorOp> {
842 public:
843  using OpConversionPattern::OpConversionPattern;
844  SparseTensorDeallocConverter(const TypeConverter &typeConverter,
845  MLIRContext *context, bool createDeallocs)
846  : OpConversionPattern(typeConverter, context),
847  createDeallocs(createDeallocs) {}
848 
849  LogicalResult
850  matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
851  ConversionPatternRewriter &rewriter) const override {
852  auto enc = getSparseTensorEncoding(op.getTensor().getType());
853  if (!enc)
854  return failure();
855 
856  // If user requests not to deallocate sparse tensors, simply erase the
857  // operation.
858  if (createDeallocs) {
859  // Replace the sparse tensor deallocation with field deallocations.
860  Location loc = op.getLoc();
861  auto desc = getDescriptorFromTensorTuple(
862  adaptor.getTensor(),
863  cast<RankedTensorType>(op.getTensor().getType()));
864  for (auto input : desc.getMemRefFields())
865  // Deallocate every buffer used to store the sparse tensor handler.
866  rewriter.create<memref::DeallocOp>(loc, input);
867  }
868  rewriter.eraseOp(op);
869  return success();
870  }
871 
872 private:
873  const bool createDeallocs;
874 };
875 
876 /// Sparse codegen rule for tensor rematerialization.
877 class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
878 public:
879  using OpConversionPattern::OpConversionPattern;
880  LogicalResult
881  matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
882  ConversionPatternRewriter &rewriter) const override {
883  // Prepare descriptor.
884  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
885  op.getTensor().getType());
886  // Generate optional insertion finalization code.
887  if (op.getHasInserts())
888  genEndInsert(rewriter, op.getLoc(), desc);
889  // Replace operation with resulting memrefs.
890  rewriter.replaceOpWithMultiple(op, {desc.getFields()});
891  return success();
892  }
893 };
894 
895 /// Sparse codegen rule for the expand op.
896 class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
897 public:
898  using OpConversionPattern::OpConversionPattern;
899  LogicalResult
900  matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
901  ConversionPatternRewriter &rewriter) const override {
902  if (!getSparseTensorEncoding(op.getTensor().getType()))
903  return failure();
904  Location loc = op->getLoc();
905  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
906  op.getTensor().getType());
907  const auto srcType = getSparseTensorType(op.getTensor());
908  Type eltType = srcType.getElementType();
909  Type boolType = rewriter.getIntegerType(1);
910  Type idxType = rewriter.getIndexType();
911  // All initialization should be done on entry of the loop nest.
912  rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
913 
914  // Determine the size for access expansion (always the innermost stored
915  // level size).
916  const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
917  // Generate a memref for `sz` elements of type `t`.
918  const auto genAlloc = [&](Type t) {
919  const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
920  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
921  };
922  // Allocate temporary buffers for values/filled-switch and added.
923  // We do not use stack buffers for this, since the expanded size may
924  // be rather large (as it envelops a single expanded dense dimension).
925  Value values = genAlloc(eltType);
926  Value filled = genAlloc(boolType);
927  Value added = genAlloc(idxType);
928  Value zero = constantZero(rewriter, loc, idxType);
929  // Reset the values/filled-switch to all-zero/false. Note that this
930  // introduces an O(N) operation into the computation, but this reset
931  // operation is amortized over the innermost loops for the access
932  // pattern expansion. As noted in the operation doc, we would like
933  // to amortize this setup cost even between kernels.
934  rewriter.create<linalg::FillOp>(
935  loc, ValueRange{constantZero(rewriter, loc, eltType)},
936  ValueRange{values});
937  rewriter.create<linalg::FillOp>(
938  loc, ValueRange{constantZero(rewriter, loc, boolType)},
939  ValueRange{filled});
940  // Replace expansion op with these buffers and initial coordinate.
941  assert(op.getNumResults() == 4);
942  rewriter.replaceOp(op, {values, filled, added, zero});
943  return success();
944  }
945 };
946 
947 /// Sparse codegen rule for the compress operator.
948 class SparseCompressConverter : public OpConversionPattern<CompressOp> {
949 public:
950  using OpConversionPattern::OpConversionPattern;
951  LogicalResult
952  matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
953  ConversionPatternRewriter &rewriter) const override {
954  Location loc = op->getLoc();
955  SmallVector<Value> fields;
956  auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
957  op.getTensor().getType());
958  Value values = llvm::getSingleElement(adaptor.getValues());
959  Value filled = llvm::getSingleElement(adaptor.getFilled());
960  Value added = llvm::getSingleElement(adaptor.getAdded());
961  Value count = llvm::getSingleElement(adaptor.getCount());
962  const SparseTensorType dstType(desc.getRankedTensorType());
963  Type eltType = dstType.getElementType();
964 
965  // If the innermost level is ordered, we need to sort the coordinates
966  // in the "added" array prior to applying the compression.
967  if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
968  rewriter.create<SortOp>(
969  loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
970  rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
971  // While performing the insertions, we also need to reset the elements
972  // of the values/filled-switch by only iterating over the set elements,
973  // to ensure that the runtime complexity remains proportional to the
974  // sparsity of the expanded access pattern.
975  //
976  // Generate
977  // out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
978  // crd = added[i];
979  // value = values[crd];
980  // insert({lvlCoords, crd}, value);
981  // new_memrefs = insert(in_memrefs, {lvlCoords, crd}, value);
982  // values[crd] = 0;
983  // filled[crd] = false;
984  // yield new_memrefs
985  // }
986  scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
987  Value i = loop.getInductionVar();
988 
989  Value crd = genLoad(rewriter, loc, added, i);
990  Value value = genLoad(rewriter, loc, values, crd);
991  SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
992  SmallVector<Type> flatSpTensorTps = llvm::to_vector(
993  llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
994  SmallVector<Value> flatLvlCoords = flattenValues(adaptor.getLvlCoords());
995  params.append(flatLvlCoords.begin(), flatLvlCoords.end());
996  params.push_back(crd);
997  params.push_back(value);
998  SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
999  params, /*genCall=*/true);
1000  SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
1001  genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
1002  genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
1003  rewriter.create<scf::YieldOp>(loc, insertRet);
1004 
1005  rewriter.setInsertionPointAfter(loop);
1006  // Deallocate the buffers on exit of the full loop nest.
1007  Operation *parent = getTop(op);
1008  rewriter.setInsertionPointAfter(parent);
1009  rewriter.create<memref::DeallocOp>(loc, values);
1010  rewriter.create<memref::DeallocOp>(loc, filled);
1011  rewriter.create<memref::DeallocOp>(loc, added);
1012  // Replace operation with resulting memrefs.
1013  rewriter.replaceOpWithMultiple(op, {loop->getResults()});
1014  return success();
1015  }
1016 };
1017 
1018 /// Sparse codegen rule for the insert operator.
1019 class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
1020 public:
1021  using OpConversionPattern::OpConversionPattern;
1022  LogicalResult
1023  matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
1024  ConversionPatternRewriter &rewriter) const override {
1025  auto stt = getSparseTensorType(op.getDest());
1026  if (!stt.hasEncoding())
1027  return failure();
1028  assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
1029 
1030  Location loc = op.getLoc();
1031  auto desc =
1032  getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
1033  TypeRange flatSpTensorTps = desc.getFields().getTypes();
1034  SmallVector<Value> params = llvm::to_vector(desc.getFields());
1035  SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
1036  params.append(flatIndices.begin(), flatIndices.end());
1037  params.push_back(llvm::getSingleElement(adaptor.getScalar()));
1038  SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
1039  params, /*genCall=*/true);
1040  SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
1041  // Replace operation with resulting memrefs.
1042  rewriter.replaceOpWithMultiple(op, {ret});
1043  return success();
1044  }
1045 };
1046 
1047 /// Sparse codegen rule for position accesses.
1048 class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
1049 public:
1050  using OpAdaptor = typename ToPositionsOp::Adaptor;
1051  using OpConversionPattern::OpConversionPattern;
1052  LogicalResult
1053  matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
1054  ConversionPatternRewriter &rewriter) const override {
1055  // Replace the requested position access with corresponding field.
1056  // The view is restricted to the actual size to ensure clients
1057  // of this operation truly observe size, not capacity!
1058  Location loc = op.getLoc();
1059  Level lvl = op.getLevel();
1060  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1061  op.getTensor().getType());
1062  auto mem = desc.getPosMemRef(lvl);
1063  auto size = desc.getPosMemSize(rewriter, loc, lvl);
1064  rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1065  return success();
1066  }
1067 };
1068 
1069 /// Sparse codegen rule for accessing the coordinates arrays.
1070 class SparseToCoordinatesConverter
1071  : public OpConversionPattern<ToCoordinatesOp> {
1072 public:
1073  using OpAdaptor = typename ToCoordinatesOp::Adaptor;
1074  using OpConversionPattern::OpConversionPattern;
1075  LogicalResult
1076  matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
1077  ConversionPatternRewriter &rewriter) const override {
1078  // Replace the requested coordinates access with corresponding field.
1079  // The view is restricted to the actual size to ensure clients
1080  // of this operation truly observe size, not capacity!
1081  Location loc = op.getLoc();
1082  Level lvl = op.getLevel();
1083  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1084  op.getTensor().getType());
1085  auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
1086  if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
1087  auto size = desc.getCrdMemSize(rewriter, loc, lvl);
1088  mem = genSliceToSize(rewriter, loc, mem, size);
1089  }
1090  rewriter.replaceOp(op, mem);
1091  return success();
1092  }
1093 };
1094 
1095 /// Sparse codegen rule for accessing the linear coordinates buffer.
1096 class SparseToCoordinatesBufferConverter
1097  : public OpConversionPattern<ToCoordinatesBufferOp> {
1098 public:
1099  using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
1100  using OpConversionPattern::OpConversionPattern;
1101  LogicalResult
1102  matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
1103  ConversionPatternRewriter &rewriter) const override {
1104  // Replace the requested coordinates access with corresponding field.
1105  // The view is restricted to the actual size to ensure clients
1106  // of this operation truly observe size, not capacity!
1107  Location loc = op.getLoc();
1108  Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
1109  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1110  op.getTensor().getType());
1111  auto mem = desc.getAOSMemRef();
1112  auto size = desc.getCrdMemSize(rewriter, loc, lvl);
1113  rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1114  return success();
1115  }
1116 };
1117 
1118 /// Sparse codegen rule for value accesses.
1119 class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
1120 public:
1121  using OpAdaptor = typename ToValuesOp::Adaptor;
1122  using OpConversionPattern::OpConversionPattern;
1123  LogicalResult
1124  matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
1125  ConversionPatternRewriter &rewriter) const override {
1126  // Replace the requested values access with corresponding field.
1127  // The view is restricted to the actual size to ensure clients
1128  // of this operation truly observe size, not capacity!
1129  Location loc = op.getLoc();
1130  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1131  op.getTensor().getType());
1132  auto mem = desc.getValMemRef();
1133  auto size = desc.getValMemSize(rewriter, loc);
1134  rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
1135  return success();
1136  }
1137 };
1138 
1139 /// Sparse codegen rule for the convert operator.
1140 class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
1141 public:
1142  using OpConversionPattern::OpConversionPattern;
1143  LogicalResult
1144  matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
1145  ConversionPatternRewriter &rewriter) const override {
1146  SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
1147  SparseTensorEncodingAttr encSrc =
1148  getSparseTensorEncoding(op.getSource().getType());
1149  // The output tensor cannot be a slice; those cases should have been
1150  // rejected by ConvertOp::verify() already.
1151  assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slice.");
1152  // Different encodings (except for different bit widths) should be handled
1153  // by rewriting.
1154  // We also need further rewrites if the input tensor is a slice.
1155  if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
1156  encSrc.isSlice()) {
1157  return failure();
1158  }
1159 
1160  Type retElemTp = op.getResult().getType().getElementType();
1161  Type srcElemTp = op.getSource().getType().getElementType();
1162  // Fold the trivial cases.
1163  if (retElemTp == srcElemTp && encDst == encSrc) {
1164  rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
1165  return success();
1166  }
1167  //
1168  // Do element-wise type conversion without using InsertOp.
1169  //
1170  // for each memref in srcTensor:
1171  // dst = memref.alloc
1172  // if srcMemRefType != dstMemRefType:
1173  // for every dst[i] = cast(src[i])
1174  // else:
1175  // dst = memref.copy(src)
1176  Location loc = op.getLoc();
1177  auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
1178  op.getSource().getType());
1179  SmallVector<Value> fields;
1180  foreachFieldAndTypeInSparseTensor(
1181  SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
1182  [&rewriter, &fields, srcDesc,
1183  loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
1184  LevelType /*lt*/) -> bool {
1185  // Simply reuses the storage specifier as it is an SSA value.
1186  if (fKind == SparseTensorFieldKind::StorageSpec) {
1187  fields.push_back(srcDesc.getSpecifier());
1188  } else {
1189  // Allocates new memrefs
1190  Value srcMem = srcDesc.getMemRefField(fIdx);
1191  // TODO: We can instead use the actual memSize in specifier, that
1192  // would require a subViewOp to avoid overflow when copying
1193  // values.
1194  Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
1195  auto dstMem = rewriter.create<memref::AllocOp>(
1196  loc, cast<MemRefType>(fTp), sz);
1197  if (fTp != srcMem.getType()) {
1198  // Converts elements type.
1199  scf::buildLoopNest(
1200  rewriter, loc, constantIndex(rewriter, loc, 0), sz,
1201  constantIndex(rewriter, loc, 1),
1202  [srcMem, &dstMem](OpBuilder &builder, Location loc,
1203  ValueRange ivs) {
1204  Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
1205  Value casted = genCast(builder, loc, v,
1206  dstMem.getType().getElementType());
1207  builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
1208  });
1209  } else {
1210  // TODO: We can even reuse the same memref for the new tensor,
1211  // but that requires a `ref-counting` based memory management
1212  // for shared memrefs between multiple sparse tensors.
1213  rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
1214  }
1215  fields.push_back(dstMem);
1216  }
1217  return true;
1218  });
1219 
1220  rewriter.replaceOpWithMultiple(op, {fields});
1221  return success();
1222  }
1223 };
1224 
1225 class SparseExtractSliceConverter
1226  : public OpConversionPattern<tensor::ExtractSliceOp> {
1227 public:
1228  using OpConversionPattern::OpConversionPattern;
1229  LogicalResult
1230  matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
1231  ConversionPatternRewriter &rewriter) const override {
1232  Location loc = op.getLoc();
1233  MLIRContext *ctx = op.getContext();
1234  auto srcEnc = getSparseTensorEncoding(op.getSourceType());
1235  auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
1236  // TODO: We should check these in ExtractSliceOp::verify.
1237  if (!srcEnc || !dstEnc || !dstEnc.isSlice())
1238  return failure();
1239  assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
1240 
1241  SmallVector<Value> fields;
1242  auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
1243  op.getSource().getType());
1244 
1245  auto newSpec = rewriter.create<StorageSpecifierInitOp>(
1246  loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
1247  desc.setSpecifier(newSpec);
1248 
1249  // Fills in slice information.
1250  for (auto [idx, offset, size, stride] : llvm::enumerate(
1251  op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
1252  Dimension dim = idx;
1253 
1254  Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
1255  Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);
1256  Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride);
1257  // TODO: We could probably only set the dynamic value here. But it would
1258  // require us to fill the hole when casting a static slice to a dynamic
1259  // slice.
1260  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset,
1261  dim, offsetV);
1262 
1263  // FIXME: we need to distinguish level sizes and dimension size for slices
1264  // here. Maybe we should store slice level sizes in a different array
1265  // instead of reusing it.
1266  assert(srcEnc.isIdentity());
1267  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
1268  sizeV);
1269  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
1270  dim, strideV);
1271  }
1272 
1273  // NOTE: we cannot generate tuples directly from the descriptor here, as
1274  // the descriptor is holding the original type, yet we want the slice type
1275  // here (they share every memref but with an updated specifier).
1276  rewriter.replaceOpWithMultiple(op, {desc.getFields()});
1277  return success();
1278  }
1279 };
1280 
1281 /// Sparse codegen rule for number of entries operator.
1282 class SparseNumberOfEntriesConverter
1283  : public OpConversionPattern<NumberOfEntriesOp> {
1284 public:
1285  using OpConversionPattern::OpConversionPattern;
1286  LogicalResult
1287  matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
1288  ConversionPatternRewriter &rewriter) const override {
1289  // Query memSizes for the actually stored values.
1290  // FIXME: the nse value computed in this way might be wrong when there is
1291  // any "loose_compressed" level.
1292  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1293  op.getTensor().getType());
1294  rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
1295  return success();
1296  }
1297 };
1298 
1299 struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
1300  using OpConversionPattern::OpConversionPattern;
1301  LogicalResult
1302  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
1303  ConversionPatternRewriter &rewriter) const override {
1304  Location loc = op.getLoc();
1305  const auto stt = getSparseTensorType(op.getResult());
1306 
1307  SmallVector<Value> fields;
1308 
1309  foreachFieldAndTypeInSparseTensor(
1310  stt,
1311  [&rewriter, &fields, &op, &stt,
1312  loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
1313  Level /*lvl*/, LevelType lt) -> bool {
1314  assert(fields.size() == fIdx);
1315  if (fKind == SparseTensorFieldKind::StorageSpec) {
1316  fields.push_back(
1317  SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
1318  } else {
1319  // Else simply takes the inputs.
1320  Value tensor = fKind == SparseTensorFieldKind::ValMemRef
1321  ? op.getValues()
1322  : op.getLevels()[fIdx];
1323  // TODO: handle batch.
1324  TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
1325  if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
1326  // Flattens the buffer to batchLvlRank.
1327  auto reassoc = getReassociationForFlattening(
1328  mem.getType(), stt.getBatchLvlRank());
1329  mem = rewriter.create<memref::CastOp>(
1330  loc, fType,
1331  rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
1332  } else {
1333  mem = rewriter.create<memref::CastOp>(loc, fType, mem);
1334  }
1335  fields.push_back(mem);
1336  }
1337  return true;
1338  });
1339 
1340  MutSparseTensorDescriptor desc(stt, fields);
1341  Value c0 = constantIndex(rewriter, loc, 0);
1342  Value c1 = constantIndex(rewriter, loc, 1);
1343  Value c2 = constantIndex(rewriter, loc, 2);
1344  Value posBack = c0; // index to the last value in the position array
1345  Value memSize = c1; // memory size for current array
1346 
1347  Level trailCOOStart = stt.getAoSCOOStart();
1348  Level trailCOORank = stt.getLvlRank() - trailCOOStart;
1349  // Sets up SparseTensorSpecifier.
1350  for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
1351  assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
1352 
1353  // Sets up the level size.
1354  auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
1355  desc.setLvlSize(rewriter, loc, lvl, lvlSize);
1356  // We use a single AOS array to store the trailing COO, so there is only
1357  // one memory size to set for the entire COO section.
1358  if (lvl > trailCOOStart)
1359  continue;
1360 
1361  // Sets up the memory size by reading the last value in position array.
1362  LevelType lt = stt.getLvlType(lvl);
1363  // Simply forwards the position index when this is a dense level.
1364  if (lt.isa<LevelFormat::Dense>()) {
1365  memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
1366  posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
1367  continue;
1368  }
1369  if (lt.isa<LevelFormat::Batch>()) {
1370  // Skips batch levels as they are not linearized.
1371  // FIXME: this assumes that every batch has the same number of nse; needs
1372  // to be generalized to handle varied-size batches.
1373  continue;
1374  }
1375 
1376  if (isWithPosLT(lt)) {
1377  assert(isCompressedLT(lt) || isLooseCompressedLT(lt));
1378  if (isLooseCompressedLT(lt)) {
1379  memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
1380  posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
1381  } else {
1382  assert(isCompressedLT(lt));
1383  posBack = memSize;
1384  memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
1385  }
1386  desc.setPosMemSize(rewriter, loc, lvl, memSize);
1387  // The last value in the position array is the memory size for the next level.
1388  // FIXME: this assumes that every batch has the same nse; this needs to
1389  // be generalized to handle varied-size batches.
1390  SmallVector<Value> batched(stt.getBatchLvlRank(),
1391  constantIndex(rewriter, loc, 0));
1392  batched.push_back(posBack);
1393  memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
1394  posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
1395  }
1396  assert(isWithCrdLT(lt) && lvl <= trailCOOStart);
1397  // FIXME: This seems to be unnecessarily complex, can we simplify it?
1398  if (lvl == trailCOOStart) {
1399  Value cooSz = rewriter.create<arith::MulIOp>(
1400  loc, memSize, constantIndex(rewriter, loc, trailCOORank));
1401  desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
1402  } else {
1403  desc.setCrdMemSize(rewriter, loc, lvl, memSize);
1404  }
1405  }
1406  desc.setValMemSize(rewriter, loc, memSize);
1407 
1408  rewriter.replaceOpWithMultiple(op, {desc.getFields()});
1409  return success();
1410  }
1411 };
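// For illustration only (not generated code): for a static 4x8 CSR tensor
// with `nse` stored entries, the specifier setup above computes
//   lvl 0 (dense)     : memSize = 4 * 1 = 4, posBack = 3
//   lvl 1 (compressed): posMemSize = 4 + 1 = 5, memSize = positions[4] = nse,
//                       crdMemSize = nse
// and finally valMemSize = nse.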
1412 
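/// Sparse codegen rule for the sparse_tensor.disassemble operator: copies
/// each storage buffer (positions, coordinates, values) into the caller
/// provided output buffers, trimmed to the size recorded in the storage
/// specifier, and returns the used length of every buffer alongside the
/// buffers themselves.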
1413 struct SparseDisassembleOpConverter
1414  : public OpConversionPattern<DisassembleOp> {
1416  SparseDisassembleOpConverter(const TypeConverter &typeConverter,
1417  MLIRContext *context)
1418  : OpConversionPattern(typeConverter, context) {}
1419 
1420  LogicalResult
1421  matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
1422  ConversionPatternRewriter &rewriter) const override {
1423  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
1424  op.getTensor().getType());
1425  Location loc = op.getLoc();
1426  SmallVector<Value> retMem;
1427  SmallVector<Value> retLen;
1428  desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem,
1429  &retLen](FieldIndex fid,
1430  SparseTensorFieldKind fKind,
1431  Level lvl, LevelType lt) -> bool {
1432  if (fKind == SparseTensorFieldKind::StorageSpec)
1433  return true;
1434  SparseTensorType stt(desc.getRankedTensorType());
1435  Value sz, src;
1436  TypedValue<BaseMemRefType> dst;
1437  if (fKind == SparseTensorFieldKind::ValMemRef) {
1438  sz = desc.getValMemSize(rewriter, loc);
1439  src = desc.getValMemRef();
1440  dst = genToMemref(rewriter, loc, op.getOutValues());
1441 
1442  retMem.push_back(dst);
1443  Type valLenTp = op.getValLen().getType();
1444  retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
1445  } else {
1446  assert(fKind == SparseTensorFieldKind::PosMemRef ||
1447  fKind == SparseTensorFieldKind::CrdMemRef);
1448 
1449  sz = fKind == SparseTensorFieldKind::PosMemRef
1450  ? desc.getPosMemSize(rewriter, loc, lvl)
1451  : desc.getCrdMemSize(rewriter, loc, lvl);
1452  src = desc.getMemRefField(fid);
1453  dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
1454  retMem.push_back(dst);
1455  // Retrieves the corresponding level length type.
1456  Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
1457  retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
1458  }
1459  Value flatOut = dst;
1460  if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
1461  auto reassoc =
1462  getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());
1463  flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
1464  }
1465  Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
1466  Value srcMem = genSliceToSize(rewriter, loc, src, sz);
1467  rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
1468  return true;
1469  });
1470 
1471  // Converts MemRefs back to Tensors.
1472  SmallVector<Value> retValues = llvm::to_vector(
1473  llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
1474  return rewriter.create<bufferization::ToTensorOp>(loc, v);
1475  }));
1476  // Appends the actual memory length used in each buffer returned.
1477  retValues.append(retLen.begin(), retLen.end());
1478  rewriter.replaceOp(op, retValues);
1479  return success();
1480  }
1481 };
1482 
1483 struct SparseNewConverter : public OpConversionPattern<NewOp> {
1484  using OpConversionPattern::OpConversionPattern;
1485  LogicalResult
1486  matchAndRewrite(NewOp op, OpAdaptor adaptor,
1487  ConversionPatternRewriter &rewriter) const override {
1488  Location loc = op.getLoc();
1489  const auto dstTp = getSparseTensorType(op.getResult());
1490  // Creating COO with NewOp is handled by direct IR codegen. All other cases
1491  // are handled by rewriting.
1492  if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
1493  return failure();
1494 
1495  // Implement as follows:
1496  // %reader = @createCheckedSparseTensorReader(%filename)
1497  // %nse = @getSparseTensorNSE(%reader)
1498  // %coo = bufferization.alloc_tensor an ordered COO with
1499  // dst dim ordering, size_hint = %nse
1500  // %coordinates = sparse_tensor.coordinates_buffer(%coo)
1501  // %values = sparse_tensor.values(%coo)
1502  // %isSorted = @sparseTensorReaderReadToBuffers(%coordinates, %values)
1503  // if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
1504  // update storage specifier
1505  // @delSparseTensorReader(%reader)
1506  SmallVector<Value> dimSizesValues;
1507  Value dimSizesBuffer;
1508  Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
1509  dimSizesValues, dimSizesBuffer);
1510 
1511  // Get the number of stored entries.
1512  const Type indexTp = rewriter.getIndexType();
1513  Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
1514  {indexTp}, {reader}, EmitCInterface::Off)
1515  .getResult(0);
1516 
1517  // Construct the lvl sizes and the dim2lvl/lvl2dim buffers.
1518  SmallVector<Value> lvlSizesValues;
1519  Value dim2lvlBuffer;
1520  Value lvl2dimBuffer;
1521  genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
1522  lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
1523 
1524  // Construct allocation for each field.
1525  Value sizeHint = nse;
1526  SmallVector<Value> fields;
1527  createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint,
1528  lvlSizesValues, fields);
1529 
1530  // Read the COO tensor data.
1531  MutSparseTensorDescriptor desc(dstTp, fields);
1532  Value xs = desc.getAOSMemRef();
1533  Value ys = desc.getValMemRef();
1534  const Type boolTp = rewriter.getIntegerType(1);
1535  const Type elemTp = dstTp.getElementType();
1536  const Type crdTp = dstTp.getCrdType();
1537  SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
1538  overheadTypeFunctionSuffix(crdTp),
1539  primaryTypeFunctionSuffix(elemTp)};
1540  Value isSorted =
1541  createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
1542  {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
1543  EmitCInterface::On)
1544  .getResult(0);
1545 
1546  // If the destination tensor is a sorted COO, we need to sort the COO tensor
1547  // data if the input elements aren't sorted yet.
1548  const Level lvlRank = dstTp.getLvlRank();
1549  if (dstTp.isOrderedLvl(lvlRank - 1)) {
1550  Value kFalse = constantI1(rewriter, loc, false);
1551  Value notSorted = rewriter.create<arith::CmpIOp>(
1552  loc, arith::CmpIPredicate::eq, isSorted, kFalse);
1553  scf::IfOp ifOp =
1554  rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
1555  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1556  auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
1557  rewriter.create<SortOp>(loc, nse, xs, ValueRange{ys}, xPerm,
1558  rewriter.getIndexAttr(0),
1559  SparseTensorSortKind::HybridQuickSort);
1560  rewriter.setInsertionPointAfter(ifOp);
1561  }
1562 
1563  // Set PosMemRef0[1] = nse.
1564  const Value c1 = constantIndex(rewriter, loc, 1);
1565  const Value posMemref0 = desc.getPosMemRef(0);
1566  const Type posTp = dstTp.getPosType();
1567  const Value posNse = genCast(rewriter, loc, nse, posTp);
1568  rewriter.create<memref::StoreOp>(loc, posNse, posMemref0, c1);
1569 
1570  // Update storage specifier.
1571  Value coordinatesSize = rewriter.create<arith::MulIOp>(
1572  loc, nse, constantIndex(rewriter, loc, lvlRank));
1573  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0,
1574  coordinatesSize);
1575  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize,
1576  std::nullopt, nse);
1577 
1578  // Release the sparse tensor reader.
1579  createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
1580  EmitCInterface::Off);
1581 
1582  // Replace operation with resulting memrefs.
1583  rewriter.replaceOpWithMultiple(op, {fields});
1584  return success();
1585  }
1586 };
1587 
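/// Sparse codegen rule for the sparse_tensor.has_runtime_library operator:
/// since this pass lowers sparse code without the runtime support library,
/// the operation simply folds to the constant false.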
1588 struct SparseHasRuntimeLibraryConverter
1589  : public OpConversionPattern<HasRuntimeLibraryOp> {
1590  using OpConversionPattern::OpConversionPattern;
1591  LogicalResult
1592  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
1593  ConversionPatternRewriter &rewriter) const override {
1594  auto i1Type = rewriter.getI1Type();
1595  rewriter.replaceOpWithNewOp<arith::ConstantOp>(
1596  op, i1Type, rewriter.getIntegerAttr(i1Type, 0));
1597  return success();
1598  }
1599 };
1600 
1601 } // namespace
1602 
1603 //===----------------------------------------------------------------------===//
1604 // Public method for populating conversion rules.
1605 //===----------------------------------------------------------------------===//
1606 
1607 /// Populates the given patterns list with conversion rules required for
1608 /// the sparsification of linear algebra operations.
1609 void mlir::populateSparseTensorCodegenPatterns(
1610  const TypeConverter &typeConverter, RewritePatternSet &patterns,
1611  bool createSparseDeallocs, bool enableBufferInitialization) {
1612  patterns.add<
1613  SparseAssembleOpConverter, SparseDisassembleOpConverter,
1614  SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
1615  SparseCastConverter, SparseExtractSliceConverter,
1616  SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter,
1617  SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter,
1618  SparseSliceGetterOpConverter<ToSliceOffsetOp,
1619  StorageSpecifierKind::DimOffset>,
1620  SparseSliceGetterOpConverter<ToSliceStrideOp,
1621  StorageSpecifierKind::DimStride>,
1622  SparseToPositionsConverter, SparseToCoordinatesConverter,
1623  SparseToCoordinatesBufferConverter, SparseToValuesConverter,
1624  SparseConvertConverter, SparseNewConverter,
1625  SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
1626  typeConverter, patterns.getContext());
1627  patterns.add<SparseTensorDeallocConverter>(
1628  typeConverter, patterns.getContext(), createSparseDeallocs);
1629  patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
1630  typeConverter, patterns.getContext(), enableBufferInitialization);
1631 }
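// A minimal usage sketch (not part of this file): a conversion pass would
// typically combine these patterns with a sparse-tensor-to-buffer type
// converter and run a partial conversion; the converter class name below is
// assumed from the dialect's pass infrastructure.
//
//   SparseTensorTypeToBufferConverter converter;
//   ConversionTarget target(*ctx);
//   RewritePatternSet patterns(ctx);
//   populateSparseTensorCodegenPatterns(converter, patterns,
//                                       /*createSparseDeallocs=*/false,
//                                       /*enableBufferInitialization=*/false);
//   if (failed(applyPartialConversion(rootOp, target, std::move(patterns))))
//     signalPassFailure();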