SparseTensorCodegen.cpp
1 //===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // A pass that converts sparse tensor types and primitives to actual compiler
10 // visible buffers and actual compiler IR that implements these primitives on
11 // the selected sparse tensor storage schemes. This pass provides an alternative
12 // to the SparseTensorConversion pass, eliminating the dependence on a runtime
13 // support library (other than for file I/O), and providing many more
14 // opportunities for subsequent compiler optimization of the generated code.
15 //
16 //===----------------------------------------------------------------------===//
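//
// For example, a 2-D CSR-like sparse tensor is typically lowered by this pass
// into a flat list of buffers plus a storage specifier: a positions memref,
// a coordinates memref, a values memref, and a sparse_tensor.storage_specifier
// value that records the actual sizes used within those buffers.
//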
17 
18 #include "CodegenUtils.h"
19 #include "SparseTensorDescriptor.h"
20 
32 
33 #include <optional>
34 
35 using namespace mlir;
36 using namespace mlir::sparse_tensor;
37 
38 //===----------------------------------------------------------------------===//
39 // Helper methods.
40 //===----------------------------------------------------------------------===//
41 
42 /// Flattens a list of operands that may contain sparse tensors.
43 static void flattenOperands(ValueRange operands,
44  SmallVectorImpl<Value> &flattened) {
45  // In case of
46  // sparse_tensor, c, sparse_tensor
47  // ==>
48  // memref ..., c, memref ...
49  for (auto operand : operands) {
50  if (getSparseTensorEncoding(operand.getType())) {
51  auto tuple = getTuple(operand);
 52  // An unrealized_conversion_cast will be inserted by the type converter
 53  // to bridge the gap introduced by the 1:N conversion between sparse
 54  // tensors and their fields. In this case, take the operands of the cast
 55  // and replace the sparse tensor output with the flattened type array.
56  flattened.append(tuple.getOperands().begin(), tuple.getOperands().end());
57  } else {
58  flattened.push_back(operand);
59  }
60  }
61 }
62 
63 /// Generates a load with proper `index` typing.
64 static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
65  idx = genCast(builder, loc, idx, builder.getIndexType());
66  return builder.create<memref::LoadOp>(loc, mem, idx);
67 }
68 
69 /// Generates a store with proper `index` typing and proper value.
70 static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
71  Value idx) {
72  idx = genCast(builder, loc, idx, builder.getIndexType());
73  val = genCast(builder, loc, val,
74  cast<ShapedType>(mem.getType()).getElementType());
75  builder.create<memref::StoreOp>(loc, val, mem, idx);
76 }
77 
78 /// Creates a straightforward counting for-loop.
 79 static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
 80  MutableArrayRef<Value> fields,
 81  Value lower = Value()) {
82  Type indexType = builder.getIndexType();
83  if (!lower)
84  lower = constantZero(builder, loc, indexType);
85  Value one = constantOne(builder, loc, indexType);
86  scf::ForOp forOp = builder.create<scf::ForOp>(loc, lower, upper, one, fields);
87  for (unsigned i = 0, e = fields.size(); i < e; i++)
88  fields[i] = forOp.getRegionIterArg(i);
89  builder.setInsertionPointToStart(forOp.getBody());
90  return forOp;
91 }
92 
93 /// Creates a push back operation.
 94 static void createPushback(OpBuilder &builder, Location loc,
 95  MutSparseTensorDescriptor desc,
 96  SparseTensorFieldKind kind, std::optional<Level> lvl,
 97  Value value, Value repeat = Value()) {
98  Type etp = desc.getMemRefElementType(kind, lvl);
99  Value field = desc.getMemRefField(kind, lvl);
100  StorageSpecifierKind specFieldKind = toSpecifierKind(kind);
101 
102  auto pushBackOp = builder.create<PushBackOp>(
103  loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field,
104  genCast(builder, loc, value, etp), repeat);
105 
106  desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
107  desc.setSpecifierField(builder, loc, specFieldKind, lvl,
108  pushBackOp.getNewSize());
109 }
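 // For illustration: a call such as
 //   createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
 //                  /*value=*/posZero, /*repeat=*/linear)
 // appends `linear` copies of zero to the positions buffer of `lvl` and records
 // the new size in the storage specifier (see allocSchemeForRank below).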
110 
111 /// Generates code that allocates a sparse storage scheme for given rank.
112 static void allocSchemeForRank(OpBuilder &builder, Location loc,
113  MutSparseTensorDescriptor desc, Level startLvl) {
114  const SparseTensorType stt(desc.getRankedTensorType());
115  Value linear = constantIndex(builder, loc, 1);
116  const Level lvlRank = stt.getLvlRank();
117  for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
118  const auto lt = stt.getLvlType(lvl);
119  if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
120  // Append linear x positions, initialized to zero. Since each compressed
121  // dimension initially already has a single zero entry, this maintains
122  // the desired "linear + 1" length property at all times. For loose
123  // compression, we multiply linear by two in order to append both the
124  // lo/hi positions.
125  Value posZero = constantZero(builder, loc, stt.getPosType());
126  if (isLooseCompressedLT(lt)) {
127  Value two = constantIndex(builder, loc, 2);
128  linear = builder.create<arith::MulIOp>(loc, linear, two);
129  }
130  createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
131  /*value=*/posZero, /*repeat=*/linear);
132  return;
133  } else if (isSingletonLT(lt) || is2OutOf4LT(lt)) {
134  return; // nothing to do
135  }
136  // Keep compounding the size, but nothing needs to be initialized
137  // at this level. We will eventually reach a compressed level or
138  // otherwise the values array for the from-here "all-dense" case.
139  assert(isDenseLT(lt));
140  Value size = desc.getLvlSize(builder, loc, lvl);
141  linear = builder.create<arith::MulIOp>(loc, linear, size);
142  }
143  // Reached values array so prepare for an insertion.
 144  Value valZero = constantZero(builder, loc, stt.getElementType());
 145  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
 146  std::nullopt, /*value=*/valZero, /*repeat=*/linear);
147 }
148 
 149 /// Creates allocation operation.
 150 static Value createAllocation(OpBuilder &builder, Location loc,
 151  MemRefType memRefType, Value sz,
 152  bool enableInit) {
153  Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
154  Type elemType = memRefType.getElementType();
155  if (enableInit) {
156  Value fillValue = constantZero(builder, loc, elemType);
157  builder.create<linalg::FillOp>(loc, fillValue, buffer);
158  }
159  return buffer;
160 }
161 
162 /// Creates the dim sizes array, filling in from dynamic sizes.
163 static void createDimSizes(OpBuilder &builder, Location loc,
164  SparseTensorType stt, ValueRange dynSizes,
165  /*out*/ SmallVectorImpl<Value> &dimSizesValues) {
166  const Dimension dimRank = stt.getDimRank();
167  dimSizesValues.clear();
168  dimSizesValues.reserve(dimRank);
169  unsigned i = 0;
170  for (const Size sz : stt.getDimShape())
171  dimSizesValues.push_back(ShapedType::isDynamic(sz)
172  ? dynSizes[i++]
173  : constantIndex(builder, loc, sz));
174 }
175 
 176 /// Creates allocation for each field in sparse tensor type. Note that
 177 /// for all dynamic memrefs in the sparse tensor storage layout, the
 178 /// memory size is really the capacity of the "vector", while the actual
 179 /// size resides in the sizes array.
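/// For example, a values memref may be allocated with a heuristic capacity of
/// 16 elements while the storage specifier reports an actual size of 3; only
/// the size recorded in the specifier is authoritative.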
180 static void createAllocFields(OpBuilder &builder, Location loc,
181  SparseTensorType stt, bool enableInit,
182  Value sizeHint,
183  SmallVectorImpl<Value> &lvlSizesValues,
184  /*out*/ SmallVectorImpl<Value> &fields) {
185  Level lvlRank = stt.getLvlRank();
186  // Set up some heuristic sizes. We try to set the initial
187  // size based on available information. Otherwise we just
188  // initialize a few elements to start the reallocation chain.
189  // TODO: refine this
190  Value posHeuristic, crdHeuristic, valHeuristic;
191  if (stt.isAllDense()) {
192  valHeuristic = lvlSizesValues[0];
193  for (Level lvl = 1; lvl < lvlRank; lvl++)
194  valHeuristic =
195  builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
196  } else if (sizeHint) {
197  if (getCOOStart(stt.getEncoding()) == 0) {
198  posHeuristic = constantIndex(builder, loc, 2);
199  crdHeuristic = builder.create<arith::MulIOp>(
200  loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
201  } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
202  posHeuristic = builder.create<arith::AddIOp>(
203  loc, sizeHint, constantIndex(builder, loc, 1));
204  crdHeuristic = sizeHint;
205  } else {
206  posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
207  }
208  valHeuristic = sizeHint;
209  } else {
210  posHeuristic = crdHeuristic = valHeuristic =
211  constantIndex(builder, loc, 16);
212  }
 213  // Initializes all fields: an initial storage specifier and allocated
 214  // positions/coordinates/values memrefs (with heuristic capacity).
 215  foreachFieldAndTypeInSparseTensor(
 216  stt,
217  [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
218  enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
219  Level /*lvl*/, LevelType /*lt*/) -> bool {
220  assert(fields.size() == fIdx);
221  Value field;
 222  switch (fKind) {
 223  case SparseTensorFieldKind::StorageSpec:
 224  field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
 225  break;
 226  case SparseTensorFieldKind::PosMemRef:
 227  field = createAllocation(builder, loc, cast<MemRefType>(fType),
 228  posHeuristic, enableInit);
 229  break;
 230  case SparseTensorFieldKind::CrdMemRef:
 231  field = createAllocation(builder, loc, cast<MemRefType>(fType),
 232  crdHeuristic, enableInit);
 233  break;
 234  case SparseTensorFieldKind::ValMemRef:
 235  field = createAllocation(builder, loc, cast<MemRefType>(fType),
 236  valHeuristic, enableInit);
 237  break;
 238  }
239  assert(field);
240  fields.push_back(field);
241  // Returns true to continue the iteration.
242  return true;
243  });
244  // Initialize the storage scheme to an empty tensor. Sets the lvlSizes
245  // and gives all position fields an initial zero entry, so that it is
246  // easier to maintain the "linear + 1" length property.
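 // For example, for a CSR-like level the positions array starts out as [0],
 // so that after appending k entries under the first parent position it
 // naturally becomes [0, k].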
247  MutSparseTensorDescriptor desc(stt, fields);
248  Value posZero = constantZero(builder, loc, stt.getPosType());
249  for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
250  desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
251  const auto lt = stt.getLvlType(lvl);
252  if (isCompressedLT(lt) || isLooseCompressedLT(lt))
253  createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
254  /*value=*/posZero);
255  }
256  allocSchemeForRank(builder, loc, desc, /*rank=*/0);
257 }
258 
259 /// Helper method that generates block specific to compressed case:
260 ///
261 /// // given: parentPos = posCursor[lvl-1]
262 /// pstart = desc.positions[lvl][parentPos]
263 /// pstop = desc.positions[lvl][parentPos+1]
264 /// plast = pstop - 1
265 /// msz = desc.coordinates[lvl].size()
266 /// if (pstart < pstop) {
267 /// isPresent = (desc.coordinates[lvl][plast] == lvlCoords[lvl])
268 /// } else { // first insertion
269 /// isPresent = false
270 /// desc.positions[lvl][parentPos] = msz
271 /// }
272 /// if (isPresent) { // coordinate is already present
273 /// pnext = plast
274 /// } else {
275 /// desc.coordinates[lvl].push_back(lvlCoords[lvl])
276 /// desc.positions[lvl][parentPos+1] = msz+1
277 /// pnext = msz
278 /// <prepare level lvl+1>
279 /// }
280 /// posCursor[lvl] = pnext
281 static Value genCompressed(OpBuilder &builder, Location loc,
282  MutSparseTensorDescriptor desc, ValueRange lvlCoords,
283  Value /*unused*/, Value parentPos, Level lvl) {
284  const SparseTensorType stt(desc.getRankedTensorType());
285  const Level lvlRank = stt.getLvlRank();
286  assert(lvl < lvlRank && "Level is out of bounds");
287  assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
288  "Level-rank mismatch");
289  SmallVector<Type> types;
290  Type indexType = builder.getIndexType();
291  Type boolType = builder.getIntegerType(1);
292  unsigned crdFidx;
293  unsigned crdStride;
294  std::tie(crdFidx, crdStride) = desc.getCrdMemRefIndexAndStride(lvl);
295  const Value one = constantIndex(builder, loc, 1);
296  const Value pp1 = builder.create<arith::AddIOp>(loc, parentPos, one);
297  const Value positionsAtLvl = desc.getPosMemRef(lvl);
298  const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
299  const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
300  const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl);
301  const Value crdStrideC =
302  crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
303  const Value msz =
304  crdStrideC ? builder.create<arith::DivUIOp>(loc, crdMsz, crdStrideC)
305  : crdMsz;
306  const Value plast = builder.create<arith::SubIOp>(
307  loc, genCast(builder, loc, pstop, indexType), one);
308  // Conditional expression.
309  Value lt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
310  pstart, pstop);
311  types.push_back(boolType);
312  scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
313  types.pop_back();
314  builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
315  Value crd =
316  genLoad(builder, loc, desc.getMemRefField(crdFidx),
317  crdStrideC ? builder.create<arith::MulIOp>(loc, plast, crdStrideC)
318  : plast);
319  Value eq = builder.create<arith::CmpIOp>(
320  loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
321  lvlCoords[lvl]);
322  builder.create<scf::YieldOp>(loc, eq);
323  builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
324  if (lvl > 0)
325  genStore(builder, loc, msz, positionsAtLvl, parentPos);
326  builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
327  builder.setInsertionPointAfter(ifOp1);
328  // If present construct. Note that for a non-unique dimension level, we
329  // simply set the condition to false and rely on CSE/DCE to clean up the IR.
330  //
331  // TODO: generate less temporary IR?
332  //
333  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
334  types.push_back(desc.getField(i).getType());
335  types.push_back(indexType);
336  const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
337  : constantI1(builder, loc, false);
338  scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true);
339  // If present (fields unaffected, update pnext to plast).
340  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
341 
 342  // FIXME: This does not look like a clean way, but it is probably the most
 343  // efficient way.
344  desc.getFields().push_back(plast);
345  builder.create<scf::YieldOp>(loc, desc.getFields());
346  desc.getFields().pop_back();
347 
348  // If !present (changes fields, update pnext).
349  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
350  Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
351  genStore(builder, loc, mszp1, positionsAtLvl, pp1);
352  createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
353  /*value=*/lvlCoords[lvl]);
354  // Prepare the next level "as needed".
355  if ((lvl + 1) < lvlRank)
356  allocSchemeForRank(builder, loc, desc, lvl + 1);
357 
358  desc.getFields().push_back(msz);
359  builder.create<scf::YieldOp>(loc, desc.getFields());
360  desc.getFields().pop_back();
361 
362  // Update fields and return next pos.
363  builder.setInsertionPointAfter(ifOp2);
364  unsigned o = 0;
365  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
366  desc.setField(i, ifOp2.getResult(o++));
367  return ifOp2.getResult(o);
368 }
369 
370 /// Generates insertion finalization code.
371 static void genEndInsert(OpBuilder &builder, Location loc,
372  SparseTensorDescriptor desc) {
373  const SparseTensorType stt(desc.getRankedTensorType());
374  const Level lvlRank = stt.getLvlRank();
375  for (Level lvl = 0; lvl < lvlRank; lvl++) {
376  const auto lt = stt.getLvlType(lvl);
377  if (isCompressedLT(lt)) {
378  // Compressed dimensions need a position cleanup for all entries
379  // that were not visited during the insertion pass.
380  //
381  // TODO: avoid cleanup and keep compressed scheme consistent at all
382  // times?
383  //
384  if (lvl > 0) {
385  Type posType = stt.getPosType();
386  Value posMemRef = desc.getPosMemRef(lvl);
387  Value hi = desc.getPosMemSize(builder, loc, lvl);
388  Value zero = constantIndex(builder, loc, 0);
389  Value one = constantIndex(builder, loc, 1);
390  // Vector of only one, but needed by createFor's prototype.
391  SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
392  scf::ForOp loop = createFor(builder, loc, hi, inits, one);
393  Value i = loop.getInductionVar();
394  Value oldv = loop.getRegionIterArg(0);
395  Value newv = genLoad(builder, loc, posMemRef, i);
396  Value posZero = constantZero(builder, loc, posType);
397  Value cond = builder.create<arith::CmpIOp>(
398  loc, arith::CmpIPredicate::eq, newv, posZero);
399  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType),
400  cond, /*else*/ true);
401  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
402  genStore(builder, loc, oldv, posMemRef, i);
403  builder.create<scf::YieldOp>(loc, oldv);
404  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
405  builder.create<scf::YieldOp>(loc, newv);
406  builder.setInsertionPointAfter(ifOp);
407  builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
408  builder.setInsertionPointAfter(loop);
409  }
410  } else {
411  assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
412  is2OutOf4LT(lt));
413  }
414  }
415 }
416 
417 /// Generates a subview into the sizes.
418 static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
419  Value sz) {
420  auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
421  return builder
422  .create<memref::SubViewOp>(
423  loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
424  ValueRange{}, ValueRange{sz}, ValueRange{},
425  ArrayRef<int64_t>{0}, // static offset
426  ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
427  ArrayRef<int64_t>{1}) // static stride
428  .getResult();
429 }
430 
 431 /// Creates the reassociation array.
 432 static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
 433  ReassociationIndices reassociation;
434  for (int i = 0, e = srcTp.getRank(); i < e; i++)
435  reassociation.push_back(i);
436  return reassociation;
437 }
438 
439 //===----------------------------------------------------------------------===//
440 // Codegen rules.
441 //===----------------------------------------------------------------------===//
442 
443 namespace {
444 
445 /// Helper class to help lowering sparse_tensor.insert operation.
446 class SparseInsertGenerator
447  : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
448 public:
449  SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
450  bool genCall)
451  : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp){};
452 
453  /// Generates code along an insertion path without the need for a "cursor".
454  /// This current insertion strategy comes at the expense of some testing
455  /// overhead for each insertion. The strategy will be optimized later for
456  /// common insertion patterns. The current insertion strategy also assumes
457  /// insertions occur in "a reasonable order" that enables building the
458  /// storage scheme in an appending/inserting kind of fashion (i.e. no
459  /// in-between insertions that need data movement). The implementation
460  /// relies on CSE/DCE to clean up all bookkeeping that is not needed.
461  ///
462  /// TODO: better unord/not-unique; also generalize, optimize, specialize!
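 /// For example (illustrative), inserting value v at level coordinates (i, j)
 /// into a 2-D tensor reaches this method with args = [fields..., i, j, v].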
463  SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
464  OpBuilder &builder, Location loc) {
465  const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
466  const Level lvlRank = stt.getLvlRank();
467  // Extract fields and coordinates from args.
468  SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
469  MutSparseTensorDescriptor desc(stt, fields);
470  const SmallVector<Value> coords =
471  llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
472  Value value = args.back();
473  Value parentPos = constantZero(builder, loc, builder.getIndexType());
474  // Generate code for every level.
475  for (Level lvl = 0; lvl < lvlRank; lvl++) {
476  const auto lt = stt.getLvlType(lvl);
477  if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
478  // Create:
479  // if (!present) {
480  // coordinates[lvl].push_back(coords[lvl])
481  // <update positions and prepare level lvl + 1>
482  // }
483  // positions[lvl] = coordinates.size() - 1
484  // <insert @ positions[lvl] at next level lvl + 1>
485  if (isLooseCompressedLT(lt)) {
486  Value two = constantIndex(builder, loc, 2);
487  parentPos = builder.create<arith::MulIOp>(loc, parentPos, two);
488  }
489  parentPos =
490  genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
491  } else if (isSingletonLT(lt) || is2OutOf4LT(lt)) {
492  // Create:
493  // coordinates[lvl].push_back(coords[lvl])
494  // positions[lvl] = positions[lvl-1]
 495  // <insert @ positions[lvl] at next level lvl + 1>
 496  createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
 497  lvl, /*value=*/coords[lvl]);
498  } else {
499  assert(isDenseLT(lt));
500  // Construct the new position as:
501  // positions[lvl] = size * positions[lvl-1] + coords[lvl]
502  // <insert @ positions[lvl] at next level lvl + 1>
503  Value size = desc.getLvlSize(builder, loc, lvl);
504  Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
505  parentPos = builder.create<arith::AddIOp>(loc, mult, coords[lvl]);
506  }
507  }
508  // Reached the actual value append/insert.
 509  if (!stt.isDenseLvl(lvlRank - 1))
 510  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
 511  std::nullopt, value);
512  else
513  genStore(builder, loc, value, desc.getValMemRef(), parentPos);
514  return fields;
515  }
516 
517  std::string getMangledFuncName() {
518  // The mangled name of the function has this format:
519  // <namePrefix>_<LT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
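 // For example (illustrative), a 10x20 CSR-like tensor of f64 with default
 // coordinate/position widths could mangle to
 // "_insert_dense_compressed_10_20_f64_0_0".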
520  constexpr const char kInsertFuncNamePrefix[] = "_insert_";
521  const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
522  SmallString<32> nameBuffer;
523  llvm::raw_svector_ostream nameOstream(nameBuffer);
524  nameOstream << kInsertFuncNamePrefix;
525  const Level lvlRank = stt.getLvlRank();
526  for (Level l = 0; l < lvlRank; l++) {
527  std::string lvlType = toMLIRString(stt.getLvlType(l));
528  // Replace/remove punctuations in level properties.
529  std::replace_if(
530  lvlType.begin(), lvlType.end(),
531  [](char c) { return c == '(' || c == ','; }, '_');
532  llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
533  nameOstream << lvlType << "_";
534  }
535  // Static dim sizes are used in the generated code while dynamic sizes are
536  // loaded from the dimSizes buffer. This is the reason for adding the shape
537  // to the function name.
538  for (const auto sz : stt.getDimShape())
539  nameOstream << sz << "_";
540  // Permutation information is also used in generating insertion.
541  if (!stt.isIdentity())
542  nameOstream << stt.getDimToLvl() << "_";
543  nameOstream << stt.getElementType() << "_";
544  nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
545  return nameOstream.str().str();
546  }
547 
548 private:
549  TensorType rtp;
550 };
551 
552 /// Sparse tensor storage conversion rule for returns.
553 class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
554 public:
 555  using OpConversionPattern::OpConversionPattern;
 556  LogicalResult
 557  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
558  ConversionPatternRewriter &rewriter) const override {
559  SmallVector<Value> flattened;
560  flattenOperands(adaptor.getOperands(), flattened);
561  // Create a return with the flattened value extracted from sparse tensors.
562  rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
563  return success();
564  }
565 };
566 
567 /// Sparse tensor storage conversion rule for calls.
568 class SparseCallConverter : public OpConversionPattern<func::CallOp> {
569 public:
 570  // The default CallOp converter cannot handle 1:N type conversion.
 571  using OpConversionPattern::OpConversionPattern;
 572  LogicalResult
 573  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
574  ConversionPatternRewriter &rewriter) const override {
575  Location loc = op.getLoc();
576  // In case of:
577  // sparse_tensor, f, sparse_tensor = call @foo(...)
578  // ==>
579  // memref..., f, memref = call @foo(...) replace with
580  // cast(memref...)->sparse_tensor, f, cast(memref...)->sparse_tensor
581  SmallVector<Type> finalRetTy;
582  if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
583  return failure();
584 
585  // (1) Generates new call with flattened return value.
586  SmallVector<Value> flattened;
587  flattenOperands(adaptor.getOperands(), flattened);
588  auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
589  finalRetTy, flattened);
590  // (2) Create cast operation for sparse tensor returns.
591  SmallVector<Value> castedRet;
 592  // Tracks the offset of the current return value (of the original call)
 593  // relative to the new call (after sparse tensor flattening).
 594  unsigned retOffset = 0;
 595  // Temporary buffer to hold the flattened list of types for
 596  // a sparse tensor.
 597  SmallVector<Type> sparseFlat;
598  for (auto ret : op.getResults()) {
599  assert(retOffset < newCall.getNumResults());
600  auto retType = ret.getType();
601  if (failed(typeConverter->convertType(retType, sparseFlat)))
602  llvm_unreachable("Failed to convert type in sparse tensor codegen");
603 
 604  // Converted types cannot be empty when the type conversion succeeds.
605  assert(!sparseFlat.empty());
606  if (sparseFlat.size() > 1) {
607  auto flatSize = sparseFlat.size();
 608  ValueRange fields(iterator_range<ResultRange::iterator>(
 609  newCall.result_begin() + retOffset,
610  newCall.result_begin() + retOffset + flatSize));
611  castedRet.push_back(genTuple(rewriter, loc, retType, fields));
612  retOffset += flatSize;
613  } else {
 614  // If this is a 1:1 conversion, there is no need for casting.
615  castedRet.push_back(newCall.getResult(retOffset));
616  retOffset++;
617  }
618  sparseFlat.clear();
619  }
620 
621  assert(castedRet.size() == op.getNumResults());
622  rewriter.replaceOp(op, castedRet);
623  return success();
624  }
625 };
626 
627 /// Sparse codegen rule for level accesses.
628 class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
629 public:
 630  using OpConversionPattern::OpConversionPattern;
 631  LogicalResult
 632  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
633  ConversionPatternRewriter &rewriter) const override {
634  std::optional<int64_t> lvl = op.getConstantLvlIndex();
635  if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType()))
636  return failure();
637 
638  auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
639  auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
640 
641  rewriter.replaceOp(op, sz);
642  return success();
643  }
644 };
645 
646 // TODO: use a new SortCOO operation here instead of reusing convert op.
647 struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
 648  using OpConversionPattern::OpConversionPattern;
 649  LogicalResult
 650  matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor,
651  ConversionPatternRewriter &rewriter) const override {
652  Location loc = op.getLoc();
653  MLIRContext *ctx = op.getContext();
654 
655  SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
656  SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());
657 
658  // Should have been verified.
659  assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
662  assert(dstStt.hasSameDimToLvl(srcStt));
663 
664  // We don't need a mutable descriptor here as we perform sorting in-place.
665  auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getInputCoo());
666  auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo());
667  auto crd = desc.getAOSMemRef();
668  auto val = desc.getValMemRef();
669 
670  // Otherwise we need another data shuffle and a non-identity map.
671  assert(dstStt.hasSameDimToLvl(srcStt));
672  (void)dstStt; // to silence warning when assertion is disabled
673 
674  auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
675 
676  rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
677  rewriter.getIndexAttr(0), op.getAlgorithm());
678 
 679  // Since we do in-place sorting, the destination tensor will have the same
 680  // set of memrefs as the source tensor.
681  rewriter.replaceOp(op, adaptor.getInputCoo());
682  return success();
683  }
684 };
685 
686 template <typename Op, StorageSpecifierKind kind>
687 class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
688 public:
 689  using OpConversionPattern<Op>::OpConversionPattern;
 690  LogicalResult
 691  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
692  ConversionPatternRewriter &rewriter) const override {
 693  // Simply lowers to a specifier.get <field> operation.
694  auto desc = getDescriptorFromTensorTuple(adaptor.getSlice());
695  auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
696  op.getDim().getZExtValue());
697 
698  rewriter.replaceOp(op, v);
699  return success();
700  }
701 };
702 
703 /// Sparse codegen rule for trivial tensor casts.
704 class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
705 public:
 706  using OpConversionPattern::OpConversionPattern;
 707  LogicalResult
 708  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
709  ConversionPatternRewriter &rewriter) const override {
710  // Only rewrite identically annotated source/dest.
711  auto encDst = getSparseTensorEncoding(op.getType());
712  auto encSrc = getSparseTensorEncoding(op.getSource().getType());
713  if (!encDst || encDst != encSrc)
714  return failure();
715  rewriter.replaceOp(op, adaptor.getOperands());
716  return success();
717  }
718 };
719 
720 class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
721 public:
 722  using OpConversionPattern::OpConversionPattern;
 723  LogicalResult
 724  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
725  ConversionPatternRewriter &rewriter) const override {
726  // Simply fold the operation.
727  rewriter.replaceOp(op, adaptor.getSource());
728  return success();
729  }
730 };
731 
732 /// Sparse codegen rule for the alloc operator.
733 class SparseTensorAllocConverter
734  : public OpConversionPattern<bufferization::AllocTensorOp> {
735 public:
 736  using OpConversionPattern::OpConversionPattern;
 737  SparseTensorAllocConverter(TypeConverter &typeConverter, MLIRContext *context,
738  bool enableInit)
739  : OpConversionPattern(typeConverter, context),
740  enableBufferInitialization(enableInit) {}
741 
 742  LogicalResult
 743  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
744  ConversionPatternRewriter &rewriter) const override {
745  const auto resType = getSparseTensorType(op);
746  if (!resType.hasEncoding())
747  return failure();
748 
749  Location loc = op.getLoc();
750  // Deal with copy.
751  if (op.getCopy()) {
752  auto desc = getDescriptorFromTensorTuple(adaptor.getCopy());
753  SmallVector<Value> fields;
754  fields.reserve(desc.getNumFields());
755  // Memcpy on memref fields.
756  for (auto field : desc.getMemRefFields()) {
757  auto memrefTp = cast<MemRefType>(field.getType());
758  auto size = rewriter.create<memref::DimOp>(loc, field, 0);
759  auto copied =
760  rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size});
761  rewriter.create<memref::CopyOp>(loc, field, copied);
762  fields.push_back(copied);
763  }
764  // Reuses specifier.
765  fields.push_back(desc.getSpecifier());
766  assert(fields.size() == desc.getNumFields());
767  rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
768  return success();
769  }
770 
771  if (!resType.isIdentity()) {
772  return rewriter.notifyMatchFailure(
773  op, "try run --sparse-reinterpret-map before codegen");
774  }
 775  // Level size equals dimension size since the lvl2dim map is an identity map.
776  SmallVector<Value> lvlSizesValues;
777  createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
778  /*dimSizesValues=*/lvlSizesValues);
779 
780  // Construct allocation for each field.
781  Value sizeHint = op.getSizeHint();
782  SmallVector<Value> fields;
783  createAllocFields(rewriter, loc, resType, enableBufferInitialization,
784  sizeHint, lvlSizesValues, fields);
785 
786  // Replace operation with resulting memrefs.
787  rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
788  return success();
789  }
790 
791 private:
792  bool enableBufferInitialization;
793 };
794 
795 /// Sparse codegen rule for the empty tensor operator.
796 class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
797 public:
 798  using OpConversionPattern::OpConversionPattern;
 799  SparseTensorEmptyConverter(TypeConverter &typeConverter, MLIRContext *context,
800  bool enableInit)
801  : OpConversionPattern(typeConverter, context),
802  enableBufferInitialization(enableInit) {}
803 
 804  LogicalResult
 805  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
806  ConversionPatternRewriter &rewriter) const override {
807  const auto resType = getSparseTensorType(op);
808  if (!resType.hasEncoding())
809  return failure();
810 
811  if (!resType.isIdentity()) {
812  return rewriter.notifyMatchFailure(
813  op, "try run --sparse-reinterpret-map before codegen");
814  }
815 
816  Location loc = op.getLoc();
 817  // Level size equals dimension size since the lvl2dim map is an identity map.
818  SmallVector<Value> lvlSizesValues;
819  createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
820  /*dimSizesValues=*/lvlSizesValues);
821  // Construct allocation for each field.
822  Value sizeHint; // none
823  SmallVector<Value> fields;
824  createAllocFields(rewriter, loc, resType, enableBufferInitialization,
825  sizeHint, lvlSizesValues, fields);
826 
827  // Replace operation with resulting memrefs.
828  rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
829  return success();
830  }
831 
832 private:
833  bool enableBufferInitialization;
834 };
835 
836 /// Sparse codegen rule for the dealloc operator.
837 class SparseTensorDeallocConverter
838  : public OpConversionPattern<bufferization::DeallocTensorOp> {
839 public:
 840  using OpConversionPattern::OpConversionPattern;
 841  SparseTensorDeallocConverter(TypeConverter &typeConverter,
842  MLIRContext *context, bool createDeallocs)
843  : OpConversionPattern(typeConverter, context),
844  createDeallocs(createDeallocs) {}
845 
 846  LogicalResult
 847  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
848  ConversionPatternRewriter &rewriter) const override {
849  auto enc = getSparseTensorEncoding(op.getTensor().getType());
850  if (!enc)
851  return failure();
852 
 853  // If the user requests not to deallocate sparse tensors, simply erase the
 854  // operation.
855  if (createDeallocs) {
856  // Replace the sparse tensor deallocation with field deallocations.
857  Location loc = op.getLoc();
858  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
859  for (auto input : desc.getMemRefFields())
860  // Deallocate every buffer used to store the sparse tensor handler.
861  rewriter.create<memref::DeallocOp>(loc, input);
862  }
863  rewriter.eraseOp(op);
864  return success();
865  }
866 
867 private:
868  const bool createDeallocs;
869 };
870 
871 /// Sparse codegen rule for tensor rematerialization.
872 class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
873 public:
 874  using OpConversionPattern::OpConversionPattern;
 875  LogicalResult
 876  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
877  ConversionPatternRewriter &rewriter) const override {
878  // Prepare descriptor.
879  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
880  // Generate optional insertion finalization code.
881  if (op.getHasInserts())
882  genEndInsert(rewriter, op.getLoc(), desc);
883  // Replace operation with resulting memrefs.
884  rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
885  return success();
886  }
887 };
888 
889 /// Sparse codegen rule for the expand op.
890 class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
891 public:
 892  using OpConversionPattern::OpConversionPattern;
 893  LogicalResult
 894  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
895  ConversionPatternRewriter &rewriter) const override {
896  if (!getSparseTensorEncoding(op.getTensor().getType()))
897  return failure();
898  Location loc = op->getLoc();
899  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
900  const auto srcType = getSparseTensorType(op.getTensor());
901  Type eltType = srcType.getElementType();
902  Type boolType = rewriter.getIntegerType(1);
903  Type idxType = rewriter.getIndexType();
904  // All initialization should be done on entry of the loop nest.
905  rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
906 
907  // Determine the size for access expansion (always the innermost stored
908  // level size).
909  const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
910  // Generate a memref for `sz` elements of type `t`.
911  const auto genAlloc = [&](Type t) {
912  const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
913  return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
914  };
915  // Allocate temporary buffers for values/filled-switch and added.
916  // We do not use stack buffers for this, since the expanded size may
917  // be rather large (as it envelops a single expanded dense dimension).
918  Value values = genAlloc(eltType);
919  Value filled = genAlloc(boolType);
920  Value added = genAlloc(idxType);
921  Value zero = constantZero(rewriter, loc, idxType);
922  // Reset the values/filled-switch to all-zero/false. Note that this
923  // introduces an O(N) operation into the computation, but this reset
924  // operation is amortized over the innermost loops for the access
925  // pattern expansion. As noted in the operation doc, we would like
926  // to amortize this setup cost even between kernels.
927  rewriter.create<linalg::FillOp>(
928  loc, ValueRange{constantZero(rewriter, loc, eltType)},
929  ValueRange{values});
930  rewriter.create<linalg::FillOp>(
931  loc, ValueRange{constantZero(rewriter, loc, boolType)},
932  ValueRange{filled});
933  // Replace expansion op with these buffers and initial coordinate.
934  assert(op.getNumResults() == 4);
935  rewriter.replaceOp(op, {values, filled, added, zero});
936  return success();
937  }
938 };
939 
940 /// Sparse codegen rule for the compress operator.
941 class SparseCompressConverter : public OpConversionPattern<CompressOp> {
942 public:
 943  using OpConversionPattern::OpConversionPattern;
 944  LogicalResult
 945  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
946  ConversionPatternRewriter &rewriter) const override {
947  Location loc = op->getLoc();
948  SmallVector<Value> fields;
949  auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
950  Value values = adaptor.getValues();
951  Value filled = adaptor.getFilled();
952  Value added = adaptor.getAdded();
953  Value count = adaptor.getCount();
954  const SparseTensorType dstType(desc.getRankedTensorType());
955  Type eltType = dstType.getElementType();
956 
957  // If the innermost level is ordered, we need to sort the coordinates
958  // in the "added" array prior to applying the compression.
959  if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
960  rewriter.create<SortOp>(
961  loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
962  rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
963  // While performing the insertions, we also need to reset the elements
964  // of the values/filled-switch by only iterating over the set elements,
965  // to ensure that the runtime complexity remains proportional to the
966  // sparsity of the expanded access pattern.
967  //
968  // Generate
969  // out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
970  // crd = added[i];
971  // value = values[crd];
972  // insert({lvlCoords, crd}, value);
973  // new_memrefs = insert(in_memrefs, {lvlCoords, crd}, value);
974  // values[crd] = 0;
975  // filled[crd] = false;
976  // yield new_memrefs
977  // }
978  scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
979  Value i = loop.getInductionVar();
980 
981  Value crd = genLoad(rewriter, loc, added, i);
982  Value value = genLoad(rewriter, loc, values, crd);
983  SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
984  SmallVector<Type> flatSpTensorTps = llvm::to_vector(
985  llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
986  params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
987  params.push_back(crd);
988  params.push_back(value);
989  SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
990  params, /*genCall=*/true);
991  SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
992  genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
993  genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
994  rewriter.create<scf::YieldOp>(loc, insertRet);
995 
996  rewriter.setInsertionPointAfter(loop);
997  Value result = genTuple(rewriter, loc, dstType, loop->getResults());
998  // Deallocate the buffers on exit of the full loop nest.
999  Operation *parent = getTop(op);
1000  rewriter.setInsertionPointAfter(parent);
1001  rewriter.create<memref::DeallocOp>(loc, values);
1002  rewriter.create<memref::DeallocOp>(loc, filled);
1003  rewriter.create<memref::DeallocOp>(loc, added);
1004  // Replace operation with resulting memrefs.
1005  rewriter.replaceOp(op, result);
1006  return success();
1007  }
1008 };
1009 
1010 /// Sparse codegen rule for the insert operator.
1011 class SparseInsertConverter : public OpConversionPattern<InsertOp> {
1012 public:
 1013  using OpConversionPattern::OpConversionPattern;
 1014  LogicalResult
 1015  matchAndRewrite(InsertOp op, OpAdaptor adaptor,
1016  ConversionPatternRewriter &rewriter) const override {
1017  Location loc = op.getLoc();
1018  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
1019  TypeRange flatSpTensorTps = desc.getFields().getTypes();
1020  SmallVector<Value> params = llvm::to_vector(desc.getFields());
1021  params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
1022  params.push_back(adaptor.getValue());
1023  SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
1024  params, /*genCall=*/true);
1025  SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
1026  // Replace operation with resulting memrefs.
1027  rewriter.replaceOp(op,
1028  genTuple(rewriter, loc, op.getTensor().getType(), ret));
1029  return success();
1030  }
1031 };
1032 
1033 /// Sparse codegen rule for position accesses.
1034 class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
1035 public:
1036  using OpAdaptor = typename ToPositionsOp::Adaptor;
 1037  using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
 1038  LogicalResult
 1039  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
1040  ConversionPatternRewriter &rewriter) const override {
 1041  // Replace the requested position access with the corresponding field.
 1042  // The cast_op is inserted by the type converter to bridge the 1:N type
 1043  // conversion.
1044  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
1045  rewriter.replaceOp(op, desc.getPosMemRef(op.getLevel()));
1046  return success();
1047  }
1048 };
1049 
1050 /// Sparse codegen rule for accessing the coordinates arrays.
1051 class SparseToCoordinatesConverter
1052  : public OpConversionPattern<ToCoordinatesOp> {
1053 public:
1054  using OpAdaptor = typename ToCoordinatesOp::Adaptor;
 1055  using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
 1056  LogicalResult
 1057  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
1058  ConversionPatternRewriter &rewriter) const override {
 1059  // Replace the requested coordinates access with the corresponding field.
 1060  // The cast_op is inserted by the type converter to bridge the 1:N type
 1061  // conversion.
1062  Location loc = op.getLoc();
1063  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
1064  Value field = desc.getCrdMemRefOrView(rewriter, loc, op.getLevel());
1065 
1066  // Insert a cast to bridge the actual type to the user expected type. If the
1067  // actual type and the user expected type aren't compatible, the compiler or
1068  // the runtime will issue an error.
1069  Type resType = op.getResult().getType();
1070  if (resType != field.getType())
1071  field = rewriter.create<memref::CastOp>(loc, resType, field);
1072  rewriter.replaceOp(op, field);
1073 
1074  return success();
1075  }
1076 };
1077 
1078 /// Sparse codegen rule for accessing the linear coordinates buffer.
1079 class SparseToCoordinatesBufferConverter
1080  : public OpConversionPattern<ToCoordinatesBufferOp> {
1081 public:
1082  using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
 1083  using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
 1084  LogicalResult
 1085  matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
1086  ConversionPatternRewriter &rewriter) const override {
 1087  // Replace the requested coordinates access with the corresponding field.
 1088  // The cast_op is inserted by the type converter to bridge the 1:N type
 1089  // conversion.
1090  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
1091  rewriter.replaceOp(op, desc.getAOSMemRef());
1092 
1093  return success();
1094  }
1095 };
1096 
1097 /// Sparse codegen rule for value accesses.
1098 class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
1099 public:
1100  using OpAdaptor = typename ToValuesOp::Adaptor;
 1101  using OpConversionPattern<ToValuesOp>::OpConversionPattern;
 1102  LogicalResult
 1103  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
1104  ConversionPatternRewriter &rewriter) const override {
 1105  // Replace the requested values access with the corresponding field.
 1106  // The cast_op is inserted by the type converter to bridge the 1:N type
 1107  // conversion.
1108  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
1109  rewriter.replaceOp(op, desc.getValMemRef());
1110  return success();
1111  }
1112 };
1113 
1114 /// Sparse codegen rule for the convert operator.
1115 class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
1116 public:
 1117  using OpConversionPattern::OpConversionPattern;
 1118  LogicalResult
 1119  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
1120  ConversionPatternRewriter &rewriter) const override {
1121  SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
1122  SparseTensorEncodingAttr encSrc =
1123  getSparseTensorEncoding(op.getSource().getType());
 1124  // The output tensor cannot be a slice; such cases should have been
 1125  // rejected by ConvertOp::verify() already.
 1126  assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slice.");
 1127  // Different encodings (except for different bitwidths) should be handled
 1128  // by rewriting.
 1129  // We also need further rewrites if the input tensor is a slice.
1130  if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
1131  encSrc.isSlice()) {
1132  return failure();
1133  }
1134 
1135  Type retElemTp = op.getResult().getType().getElementType();
1136  Type srcElemTp = op.getSource().getType().getElementType();
1137  // Fold the trivial cases.
1138  if (retElemTp == srcElemTp && encDst == encSrc) {
1139  rewriter.replaceOp(op, adaptor.getSource());
1140  return success();
1141  }
1142  //
1143  // Do element-wise type conversion without using InsertOp.
1144  //
1145  // for each memref in srcTensor:
1146  // dst = memref.alloc
1147  // if srcMemRefType != dstMemRefType:
1148  // for every dst[i] = cast(src[i])
1149  // else:
1150  // dst = memref.copy(src)
1151  Location loc = op.getLoc();
1152  auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource());
1153  SmallVector<Value> fields;
 1154  foreachFieldAndTypeInSparseTensor(
 1155  SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
1156  [&rewriter, &fields, srcDesc,
1157  loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
1158  LevelType /*lt*/) -> bool {
1159  // Simply reuses the storage specifier as it is an SSA value.
1160  if (fKind == SparseTensorFieldKind::StorageSpec) {
1161  fields.push_back(srcDesc.getSpecifier());
1162  } else {
1163  // Allocates new memrefs
1164  Value srcMem = srcDesc.getMemRefField(fIdx);
1165  // TODO: We can instead use the actual memSize in specifier, that
1166  // would require a subViewOp to avoid overflow when copying
1167  // values.
1168  Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
1169  auto dstMem = rewriter.create<memref::AllocOp>(
1170  loc, cast<MemRefType>(fTp), sz);
1171  if (fTp != srcMem.getType()) {
1172  // Converts elements type.
1173  scf::buildLoopNest(
1174  rewriter, loc, constantIndex(rewriter, loc, 0), sz,
1175  constantIndex(rewriter, loc, 1),
1176  [srcMem, &dstMem](OpBuilder &builder, Location loc,
1177  ValueRange ivs) {
1178  Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
1179  Value casted = genCast(builder, loc, v,
1180  dstMem.getType().getElementType());
1181  builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
1182  });
1183  } else {
1184  // TODO: We can even reuse the same memref for the new tensor,
1185  // but that requires a `ref-counting` based memory management
1186  // for shared memrefs between multiple sparse tensors.
1187  rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
1188  }
1189  fields.push_back(dstMem);
1190  }
1191  return true;
1192  });
1193 
1194  rewriter.replaceOp(
1195  op, genTuple(rewriter, loc, op.getResult().getType(), fields));
1196  return success();
1197  }
1198 };
1199 
1200 class SparseExtractSliceConverter
1201  : public OpConversionPattern<tensor::ExtractSliceOp> {
1202 public:
 1203  using OpConversionPattern::OpConversionPattern;
 1204  LogicalResult
 1205  matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
1206  ConversionPatternRewriter &rewriter) const override {
1207  Location loc = op.getLoc();
1208  MLIRContext *ctx = op.getContext();
1209  auto srcEnc = getSparseTensorEncoding(op.getSourceType());
1210  auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
1211  // TODO: We should check these in ExtractSliceOp::verify.
1212  if (!srcEnc || !dstEnc || !dstEnc.isSlice())
1213  return failure();
1214  assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
1215 
1216  SmallVector<Value> fields;
1217  auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);
1218 
1219  auto newSpec = rewriter.create<StorageSpecifierInitOp>(
1220  loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
1221  desc.setSpecifier(newSpec);
1222 
1223  // Fills in slice information.
1224  for (auto [idx, offset, size, stride] : llvm::enumerate(
1225  op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
1226  Dimension dim = idx;
1227 
1228  Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
1229  Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);
1230  Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride);
 1231  // TODO: We could probably set only the dynamic value here. But it would
 1232  // require us to fill the hole when casting a static slice to a dynamic
 1233  // slice.
1234  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset,
1235  dim, offsetV);
1236 
1237  // FIXME: we need to distinguish level sizes and dimension size for slices
1238  // here. Maybe we should store slice level sizes in a different array
1239  // instead of reusing it.
1240  assert(srcEnc.isIdentity());
1241  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
1242  sizeV);
1243  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
1244  dim, strideV);
1245  }
1246 
 1247  // NOTE: we cannot generate tuples directly from the descriptor here, as
 1248  // the descriptor holds the original type, yet we want the slice type
 1249  // here (they share every memref but with an updated specifier).
1250  rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(),
1251  desc.getFields()));
1252  return success();
1253  }
1254 };
1255 
1256 /// Sparse codegen rule for number of entries operator.
1257 class SparseNumberOfEntriesConverter
1258  : public OpConversionPattern<NumberOfEntriesOp> {
1259 public:
 1260  using OpConversionPattern::OpConversionPattern;
 1261  LogicalResult
 1262  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
1263  ConversionPatternRewriter &rewriter) const override {
1264  // Query memSizes for the actually stored values.
1265  // FIXME: the nse value computed in this way might be wrong when there is
1266  // any "loose_compressed" level.
1267  rewriter.replaceOp(
1268  op, genValMemSize(rewriter, op.getLoc(), adaptor.getTensor()));
1269  return success();
1270  }
1271 };
1272 
1273 struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
 1274  using OpConversionPattern::OpConversionPattern;
 1275  LogicalResult
 1276  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
1277  ConversionPatternRewriter &rewriter) const override {
1278  Location loc = op.getLoc();
1279  const auto stt = getSparseTensorType(op.getResult());
1280 
1281  SmallVector<Value> fields;
1282 
 1283  foreachFieldAndTypeInSparseTensor(
 1284  stt,
1285  [&rewriter, &fields, &op, &stt,
1286  loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
1287  Level /*lvl*/, LevelType lt) -> bool {
1288  assert(fields.size() == fIdx);
1289  if (fKind == SparseTensorFieldKind::StorageSpec) {
1290  fields.push_back(
1291  SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
1292  } else {
1293  // Else simply takes the inputs.
1294  Value tensor = fKind == SparseTensorFieldKind::ValMemRef
1295  ? op.getValues()
1296  : op.getLevels()[fIdx];
1297 
1298  TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
1299  if (mem.getType().getRank() > 1) {
1300  // Flattens the buffer to rank 1.
1301  auto reassoc = getReassociationForFlattening(mem.getType());
1302  mem = rewriter.create<memref::CastOp>(
1303  loc, fType,
1304  rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
1305  } else {
1306  mem = rewriter.create<memref::CastOp>(loc, fType, mem);
1307  }
1308  fields.push_back(mem);
1309  }
1310  return true;
1311  });
1312 
1313  MutSparseTensorDescriptor desc(stt, fields);
1314  Value c0 = constantIndex(rewriter, loc, 0);
1315  Value c1 = constantIndex(rewriter, loc, 1);
1316  Value c2 = constantIndex(rewriter, loc, 2);
1317  Value posBack = c0; // index to the last value in the position array
1318  Value memSize = c1; // memory size for current array
1319 
1320  Level trailCOOStart = getCOOStart(stt.getEncoding());
1321  Level trailCOORank = stt.getLvlRank() - trailCOOStart;
1322  // Sets up SparseTensorSpecifier.
1323  for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
1324  assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
1325 
1326  // FIXME: dim/lvl confusion!
1327  // Sets up the level size.
1328  auto lvlSize = constantIndex(rewriter, loc, stt.getDimShape()[lvl]);
1329  desc.setLvlSize(rewriter, loc, lvl, lvlSize);
1330  // We use a single AOS array to store the trailing COO, so there is only
1331  // one memory size to set for the entire COO section.
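 // For example, in a 2-D COO tensor the coordinates of both levels are
 // interleaved in a single AOS buffer as [i0, j0, i1, j1, ...].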
1332  if (lvl > trailCOOStart)
1333  continue;
1334 
1335  // Sets up the memory size by reading the last value in position array.
1336  LevelType lt = stt.getLvlType(lvl);
1337  // Simply forwards the position index when this is a dense level.
1338  if (isDenseLT(lt)) {
1339  memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
1340  posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
1341  continue;
1342  }
1343 
1344  if (isWithPosLT(lt)) {
1345  assert(isCompressedLT(lt) || isLooseCompressedLT(lt));
1346  if (isLooseCompressedLT(lt)) {
1347  memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
1348  posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
1349  } else {
1350  assert(isCompressedLT(lt));
1351  posBack = memSize;
1352  memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
1353  }
1354  desc.setPosMemSize(rewriter, loc, lvl, memSize);
1355  // The last value in position array is the memory size for next level.
1356  memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), posBack);
1357  posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
1358  }
1359  assert(isWithCrdLT(lt) && lvl <= trailCOOStart);
1360  // FIXME: This seems to be unnecessarily complex, can we simplify it?
1361  if (lvl == trailCOOStart) {
1362  Value cooSz = rewriter.create<arith::MulIOp>(
1363  loc, memSize, constantIndex(rewriter, loc, trailCOORank));
1364  desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
1365  } else {
1366  desc.setCrdMemSize(rewriter, loc, lvl, memSize);
1367  }
1368  }
1369  desc.setValMemSize(rewriter, loc, memSize);
1370 
1371  rewriter.replaceOp(op, genTuple(rewriter, loc, desc));
1372  return success();
1373  }
1374 };
1375 
1376 struct SparseDisassembleOpConverter
1377  : public OpConversionPattern<DisassembleOp> {
 1378  using OpConversionPattern::OpConversionPattern;
 1379  SparseDisassembleOpConverter(TypeConverter &typeConverter,
1380  MLIRContext *context)
1381  : OpConversionPattern(typeConverter, context) {}
1382 
 1383  LogicalResult
 1384  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
1385  ConversionPatternRewriter &rewriter) const override {
1386  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
1387  Location loc = op.getLoc();
1388  SmallVector<Value> retMem;
1389  SmallVector<Value> retLen;
1390  desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem,
1391  &retLen](FieldIndex fid,
1392  SparseTensorFieldKind fKind,
1393  Level lvl, LevelType lt) -> bool {
1394  if (fKind == SparseTensorFieldKind::StorageSpec)
1395  return true;
 1397  Value sz, src;
 1398  TypedValue<BaseMemRefType> dst;
1399  if (fKind == SparseTensorFieldKind::ValMemRef) {
1400  sz = desc.getValMemSize(rewriter, loc);
1401  src = desc.getValMemRef();
1402  dst = genToMemref(rewriter, loc, op.getOutValues());
1403  // The values buffer is the last field in the descriptor, but it is the
1404  // first operand of the unpack (disassemble) operation.
1405  // TODO: maybe change the unpack/pack operations instead to be
1406  // consistent.
1407  retMem.insert(retMem.begin(), dst);
1408  Type valLenTp = op.getValLen().getType();
1409  retLen.insert(retLen.begin(),
1410  genScalarToTensor(rewriter, loc, sz, valLenTp));
1411  } else {
1412  assert(fKind == SparseTensorFieldKind::PosMemRef ||
1413  fKind == SparseTensorFieldKind::CrdMemRef);
1414 
1415  sz = fKind == SparseTensorFieldKind::PosMemRef
1416  ? desc.getPosMemSize(rewriter, loc, lvl)
1417  : desc.getCrdMemSize(rewriter, loc, lvl);
1418  src = desc.getMemRefField(fid);
1419  dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
1420  retMem.push_back(dst);
1421  // Retrieves the corresponding level length type.
1422  Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
1423  retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
1424  }
1425  Value flatOut = dst;
1426  if (dst.getType().getRank() != 1) {
1427  auto reassoc = getReassociationForFlattening(dst.getType());
1428  flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
1429  }
1430  Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
1431  Value srcMem = genSliceToSize(rewriter, loc, src, sz);
1432  rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
1433  return true;
1434  });
1435 
1436  // Converts MemRefs back to Tensors.
1437  SmallVector<Value> retValues = llvm::to_vector(
1438  llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
1439  return rewriter.create<bufferization::ToTensorOp>(loc, v);
1440  }));
1441  // Appends the actual memory length used in each buffer returned.
1442  retValues.append(retLen.begin(), retLen.end());
1443  rewriter.replaceOp(op, retValues);
1444  return success();
1445  }
1446 };
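// For illustration (a sketch, continuing the CSR-like example above, with
// `rows` and `nse` again illustrative names): disassembly copies, for every
// storage field, a size-trimmed view of the used prefix into the
// caller-provided buffer and also returns that length:
//   out_levels[0] <- positions[0 .. rows]       (length rows + 1)
//   out_levels[1] <- coordinates[0 .. nse - 1]  (length nse)
//   out_values    <- values[0 .. nse - 1]       (length nse)
// The lengths are converted with genScalarToTensor to match the declared
// result types of the disassemble operation.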
1447 
1448 struct SparseNewConverter : public OpConversionPattern<NewOp> {
1449  using OpConversionPattern::OpConversionPattern;
1450  LogicalResult
1451  matchAndRewrite(NewOp op, OpAdaptor adaptor,
1452  ConversionPatternRewriter &rewriter) const override {
1453  Location loc = op.getLoc();
1454  const auto dstTp = getSparseTensorType(op.getResult());
1455  // Creating COO with NewOp is handled by direct IR codegen. All other cases
1456  // are handled by rewriting.
1457  if (!dstTp.hasEncoding() || getCOOStart(dstTp.getEncoding()) != 0)
1458  return failure();
1459 
1460  // Implement as follows:
1461  // %reader = @createCheckedSparseTensorReader(%filename)
1462  // %nse = @getSparseTensorNSE(%reader)
1463  // %coo = bufferization.alloc_tensor an ordered COO with
1464  // dst dim ordering, size_hint = %nse
1465  // %coordinates = sparse_tensor.coordinates_buffer(%coo)
1466  // %values = sparse_tensor.values(%coo)
1467  // %isSorted = @sparseTensorReaderReadToBuffers(%coordinates, %values)
1468  // if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
1469  // update storage specifier
1470  // @delSparseTensorReader(%reader)
1471  SmallVector<Value> dimSizesValues;
1472  Value dimSizesBuffer;
1473  Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
1474  dimSizesValues, dimSizesBuffer);
1475 
1476  // Get the number of stored entries.
1477  const Type indexTp = rewriter.getIndexType();
1478  Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
1479  {indexTp}, {reader}, EmitCInterface::Off)
1480  .getResult(0);
1481 
1482  // Construct the lvl sizes and the dim2lvl/lvl2dim buffers.
1483  SmallVector<Value> lvlSizesValues;
1484  Value dim2lvlBuffer;
1485  Value lvl2dimBuffer;
1486  genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
1487  lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
1488 
1489  // Construct allocation for each field.
1490  Value sizeHint = nse;
1491  SmallVector<Value> fields;
1492  createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint,
1493  lvlSizesValues, fields);
1494 
1495  // Read the COO tensor data.
1496  MutSparseTensorDescriptor desc(dstTp, fields);
1497  Value xs = desc.getAOSMemRef();
1498  Value ys = desc.getValMemRef();
1499  const Type boolTp = rewriter.getIntegerType(1);
1500  const Type elemTp = dstTp.getElementType();
1501  const Type crdTp = dstTp.getCrdType();
1502  SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
1503                                        overheadTypeFunctionSuffix(crdTp),
1504                                        primaryTypeFunctionSuffix(elemTp)};
1505  Value isSorted =
1506  createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
1507  {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
1508  EmitCInterface::On)
1509  .getResult(0);
1510 
1511  // If the destination tensor is a sorted COO, the COO data must be sorted
1512  // explicitly whenever the input elements are not already in order.
1513  const Level lvlRank = dstTp.getLvlRank();
1514  if (dstTp.isOrderedLvl(lvlRank - 1)) {
1515  Value kFalse = constantI1(rewriter, loc, false);
1516  Value notSorted = rewriter.create<arith::CmpIOp>(
1517  loc, arith::CmpIPredicate::eq, isSorted, kFalse);
1518  scf::IfOp ifOp =
1519  rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
1520  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1521  auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
1522  rewriter.create<SortOp>(loc, nse, xs, ValueRange{ys}, xPerm,
1523  rewriter.getIndexAttr(0),
1524  SparseTensorSortKind::HybridQuickSort);
1525  rewriter.setInsertionPointAfter(ifOp);
1526  }
1527 
1528  // Set PosMemRef0[1] = nse.
1529  const Value c1 = constantIndex(rewriter, loc, 1);
1530  const Value posMemref0 = desc.getPosMemRef(0);
1531  const Type posTp = dstTp.getPosType();
1532  const Value posNse = genCast(rewriter, loc, nse, posTp);
1533  rewriter.create<memref::StoreOp>(loc, posNse, posMemref0, c1);
1534 
1535  // Update storage specifier.
1536  Value coordinatesSize = rewriter.create<arith::MulIOp>(
1537  loc, nse, constantIndex(rewriter, loc, lvlRank));
1538  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0,
1539  coordinatesSize);
1540  desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize,
1541  std::nullopt, nse);
1542 
1543  // Release the sparse tensor reader.
1544  createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
1545  EmitCInterface::Off);
1546 
1547  // Replace operation with resulting memrefs.
1548  rewriter.replaceOp(op, genTuple(rewriter, loc, dstTp, fields));
1549  return success();
1550  }
1551 };
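// For illustration (an assumption about naming, not taken from this file):
// with index-typed coordinates and f64 values, the concatenated reader name
// above would be along the lines of
// "getSparseTensorReaderReadToBuffers0F64"; since the call is emitted with
// EmitCInterface::On, it is expected to bind to the corresponding
// "_mlir_ciface_..." wrapper provided by the runtime library. The exact
// suffixes come from the overhead/primary type suffix helpers.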
1552 
1553 } // namespace
1554 
1555 //===----------------------------------------------------------------------===//
1556 // Public method for populating conversion rules.
1557 //===----------------------------------------------------------------------===//
1558 
1559 /// Populates the given patterns list with conversion rules required for
1560 /// the sparsification of linear algebra operations.
1561 void mlir::populateSparseTensorCodegenPatterns(
1562     TypeConverter &typeConverter, RewritePatternSet &patterns,
1563  bool createSparseDeallocs, bool enableBufferInitialization) {
1564  patterns.add<SparseAssembleOpConverter, SparseDisassembleOpConverter,
1565  SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
1566  SparseCastConverter, SparseExtractSliceConverter,
1567  SparseTensorLoadConverter, SparseExpandConverter,
1568  SparseCompressConverter, SparseInsertConverter,
1569  SparseReorderCOOConverter, SparseReMapConverter,
1570  SparseSliceGetterOpConverter<ToSliceOffsetOp,
1571  StorageSpecifierKind::DimOffset>,
1572  SparseSliceGetterOpConverter<ToSliceStrideOp,
1573  StorageSpecifierKind::DimStride>,
1574  SparseToPositionsConverter, SparseToCoordinatesConverter,
1575  SparseToCoordinatesBufferConverter, SparseToValuesConverter,
1576  SparseConvertConverter, SparseNewConverter,
1577  SparseNumberOfEntriesConverter>(typeConverter,
1578  patterns.getContext());
1579  patterns.add<SparseTensorDeallocConverter>(
1580  typeConverter, patterns.getContext(), createSparseDeallocs);
1581  patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
1582  typeConverter, patterns.getContext(), enableBufferInitialization);
1583 }
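// A minimal usage sketch (assuming a surrounding conversion pass; the
// converter name and target setup below are assumptions for illustration,
// not taken from this file): these patterns are typically combined with a
// type converter that expands sparse tensor types into their storage
// fields, then applied through a dialect conversion.
//
//   RewritePatternSet patterns(ctx);
//   SparseTensorTypeToBufferConverter converter;  // assumed converter name
//   ConversionTarget target(*ctx);
//   // ... configure legal/illegal dialects and ops on `target` ...
//   populateSparseTensorCodegenPatterns(
//       converter, patterns, /*createSparseDeallocs=*/false,
//       /*enableBufferInitialization=*/false);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();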