//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "Detail/DimLvlMapParser.h"

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed
// as well.
namespace mlir::sparse_tensor {
llvm::hash_code hash_value(LevelType lt) {
  return llvm::hash_value(static_cast<uint64_t>(lt));
}
} // namespace mlir::sparse_tensor

//===----------------------------------------------------------------------===//
// Local Convenience Methods.
//===----------------------------------------------------------------------===//

static constexpr bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect StorageLayout.
//===----------------------------------------------------------------------===//

static constexpr Level kInvalidLevel = -1u;
static constexpr Level kInvalidFieldIndex = -1u;
static constexpr FieldIndex kDataFieldStartingIdx = 0;

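// Illustrative example (a sketch, not a normative specification): for a CSR
// matrix, i.e., lvlTypes = [dense, compressed], foreachField below visits
//   0: PosMemRef(lvl=1), 1: CrdMemRef(lvl=1), 2: ValMemRef, 3: StorageSpec,
// since a dense level stores neither positions nor coordinates.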
void StorageLayout::foreachField(
    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) const {
  const auto lvlTypes = enc.getLvlTypes();
  const Level lvlRank = enc.getLvlRank();
  SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments();
  FieldIndex fieldIdx = kDataFieldStartingIdx;

  ArrayRef cooSegsRef = cooSegs;
  // Per-level storage.
  for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
    const auto lt = lvlTypes[l];
    if (isWithPosLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
        return;
    }
    if (isWithCrdLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
        return;
    }
    if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
      if (!cooSegsRef.front().isSoA) {
        // AoS COO: all singletons are fused into one memref. Skip the entire
        // COO segment.
        l = cooSegsRef.front().lvlRange.second;
      } else {
        // SoA COO: each singleton level has its own memref.
        l++;
      }
      // Expire handled COO segment.
      cooSegsRef = cooSegsRef.drop_front();
    } else {
      // Non-COO levels.
      l++;
    }
  }
  // The values array.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
  // Put metadata at the end.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
}

void sparse_tensor::foreachFieldAndTypeInSparseTensor(
    SparseTensorType stt,
    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) {
  assert(stt.hasEncoding());
  // Construct the basic types.
  const Type crdType = stt.getCrdType();
  const Type posType = stt.getPosType();
  const Type eltType = stt.getElementType();

  const Type specType = StorageSpecifierType::get(stt.getEncoding());
  // memref<? x pos>  positions
  const Type posMemType = MemRefType::get({ShapedType::kDynamic}, posType);
  // memref<? x crd>  coordinates
  const Type crdMemType = MemRefType::get({ShapedType::kDynamic}, crdType);
  // memref<? x eltType>  values
  const Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);

  StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
                                   callback](FieldIndex fieldIdx,
                                             SparseTensorFieldKind fieldKind,
                                             Level lvl, LevelType lt) -> bool {
    switch (fieldKind) {
    case SparseTensorFieldKind::StorageSpec:
      return callback(specType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::PosMemRef:
      return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::CrdMemRef:
      return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::ValMemRef:
      return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
    };
    llvm_unreachable("unrecognized field kind");
  });
}

unsigned StorageLayout::getNumFields() const {
  unsigned numFields = 0;
  foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
    numFields++;
    return true;
  });
  return numFields;
}

unsigned StorageLayout::getNumDataFields() const {
  unsigned numFields = 0; // one value memref
  foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
    if (fidx >= kDataFieldStartingIdx)
      numFields++;
    return true;
  });
  numFields -= 1; // the last field is the StorageSpecifier
  assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
  return numFields;
}

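// Illustrative example (a sketch under the AoS COO layout described above):
// for lvlTypes = [compressed(nonunique), singleton], both levels share one
// fused coordinates memref, so a CrdMemRef query for level 1 resolves to the
// field introduced at level 0 with stride 2.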
std::pair<FieldIndex, unsigned>
StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
                                      std::optional<Level> lvl) const {
  FieldIndex fieldIdx = kInvalidFieldIndex;
  unsigned stride = 1;
  if (kind == SparseTensorFieldKind::CrdMemRef) {
    assert(lvl.has_value());
    const Level cooStart = SparseTensorType(enc).getAoSCOOStart();
    const Level lvlRank = enc.getLvlRank();
    if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
      lvl = cooStart;
      stride = lvlRank - cooStart;
    }
  }
  foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
                                      SparseTensorFieldKind fKind, Level fLvl,
                                      LevelType lt) -> bool {
    if ((lvl && fLvl == lvl.value() && kind == fKind) ||
        (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
      fieldIdx = fIdx;
      // Return false to break the iteration.
      return false;
    }
    return true;
  });
  assert(fieldIdx != kInvalidFieldIndex);
  return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//

std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
  return isDynamic(v) ? std::nullopt
                      : std::make_optional(static_cast<uint64_t>(v));
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
  return getStatic(getOffset());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
  return getStatic(getStride());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
  return getStatic(getSize());
}

bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
  return isDynamic(getOffset()) && isDynamic(getStride()) &&
         isDynamic(getSize());
}

std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
  return isDynamic(v) ? "?" : std::to_string(v);
}

void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
  assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
  os << '(';
  os << getStaticString(getOffset());
  os << ", ";
  os << getStaticString(getSize());
  os << ", ";
  os << getStaticString(getStride());
  os << ')';
}

void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
  print(printer.getStream());
}

static ParseResult parseOptionalStaticSlice(int64_t &result,
                                            AsmParser &parser) {
  auto parseResult = parser.parseOptionalInteger(result);
  if (parseResult.has_value()) {
    if (parseResult.value().succeeded() && result < 0) {
      parser.emitError(
          parser.getCurrentLocation(),
          "expect positive value or ? for slice offset/size/stride");
      return failure();
    }
    return parseResult.value();
  }

  // Else, expect a '?' which represents a dynamic slice.
  result = SparseTensorDimSliceAttr::kDynamic;
  return parser.parseQuestion();
}

Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
  int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;

  if (failed(parser.parseLParen()) ||
      failed(parseOptionalStaticSlice(offset, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(size, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(stride, parser)) ||
      failed(parser.parseRParen()))
    return {};

  return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
                                                     offset, size, stride);
}

LogicalResult
SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 int64_t offset, int64_t size, int64_t stride) {
  if (!isDynamic(offset) && offset < 0)
    return emitError() << "expect non-negative value or ? for slice offset";
  if (!isDynamic(size) && size <= 0)
    return emitError() << "expect positive value or ? for slice size";
  if (!isDynamic(stride) && stride <= 0)
    return emitError() << "expect positive value or ? for slice stride";
  return success();
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
                                       AffineMap(), getPosWidth(),
                                       getCrdWidth());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
  return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
  return withDimToLvl(AffineMap());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                        unsigned crdWidth) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
                                       getDimToLvl(), getLvlToDim(), posWidth,
                                       crdWidth);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
  return withBitWidths(0, 0);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
                                       getDimToLvl(), getLvlToDim(),
                                       getPosWidth(), getCrdWidth(), dimSlices);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
  return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
}

bool SparseTensorEncodingAttr::isAllDense() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
}

bool SparseTensorEncodingAttr::isAllOrdered() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
}

bool SparseTensorEncodingAttr::isIdentity() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
}

bool SparseTensorEncodingAttr::isPermutation() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
}

Dimension SparseTensorEncodingAttr::getDimRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  const auto dimToLvl = getDimToLvl();
  return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
}

Level SparseTensorEncodingAttr::getLvlRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return getLvlTypes().size();
}

LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
  if (!getImpl())
    return LevelFormat::Dense;
  assert(l < getLvlRank() && "Level is out of bounds");
  return getLvlTypes()[l];
}

bool SparseTensorEncodingAttr::isSlice() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return !getDimSlices().empty();
}

SparseTensorDimSliceAttr
SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
  assert(isSlice() && "Is not a slice");
  const auto dimSlices = getDimSlices();
  assert(dim < dimSlices.size() && "Dimension is out of bounds");
  return dimSlices[dim];
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
  return getDimSlice(dim).getStaticOffset();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
  return getDimSlice(dim).getStaticStride();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
  return getStaticDimSliceOffset(toDim(*this, lvl));
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
  return getStaticDimSliceStride(toDim(*this, lvl));
}

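// Illustrative example of the translation performed below: with the 2x3
// block-sparse map (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2,
// d1 mod 3), a dim2lvl translation of the dimension shape 6x9 yields the
// level shape 3x3x2x3.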
SmallVector<int64_t>
SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
                                        CrdTransDirectionKind dir) const {
  if (isIdentity())
    return SmallVector<int64_t>(srcShape);

  SmallVector<int64_t> ret;
  unsigned rank =
      dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
  ret.reserve(rank);

  if (isPermutation()) {
    for (unsigned r = 0; r < rank; r++) {
      unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
                                                             : toLvl(*this, r);
      ret.push_back(srcShape[trans]);
    }
    return ret;
  }

  // Handle non-permutation maps.
  AffineMap transMap =
      dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();

  SmallVector<AffineExpr> dimRep;
  dimRep.reserve(srcShape.size());
  for (int64_t sz : srcShape) {
    if (!ShapedType::isDynamic(sz)) {
      // Push back the max coordinate for the given dimension/level size.
      dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
    } else {
      // A dynamic size; use an AffineDimExpr to symbolize the value.
      dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
    }
  }

  for (AffineExpr exp : transMap.getResults()) {
    // Do constant propagation on the affine map.
    AffineExpr evalExp =
        simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
    // Use the llvm namespace here to avoid ambiguity.
    if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
      ret.push_back(c.getValue() + 1);
    } else {
      if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
          mod && mod.getKind() == AffineExprKind::Mod) {
        // We can still infer a static bound for expressions in the form
        // "d % constant" since d % constant is in [0, constant).
        if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
          ret.push_back(bound.getValue());
          continue;
        }
      }
      ret.push_back(ShapedType::kDynamic);
    }
  }
  assert(ret.size() == rank);
  return ret;
}

ValueRange
SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
                                        ValueRange crds,
                                        CrdTransDirectionKind dir) const {
  if (!getImpl())
    return crds;

  SmallVector<Type> retType(
      dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
      builder.getIndexType());
  auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
  return transOp.getOutCrds();
}

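// For reference, the syntax accepted below looks like the following
// (an illustrative CSR example; posWidth/crdWidth are optional):
//
//   #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed),
//     posWidth = 32,
//     crdWidth = 32
//   }>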
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  // Open "<{" part.
  if (failed(parser.parseLess()))
    return {};
  if (failed(parser.parseLBrace()))
    return {};

  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<LevelType> lvlTypes;
  SmallVector<SparseTensorDimSliceAttr> dimSlices;
  AffineMap dimToLvl = {};
  AffineMap lvlToDim = {};
  unsigned posWidth = 0;
  unsigned crdWidth = 0;
  StringRef attrName;
  SmallVector<StringRef, 3> keys = {"map", "posWidth", "crdWidth"};
  while (succeeded(parser.parseOptionalKeyword(&attrName))) {
    // Detect admissible keyword.
    auto *it = find(keys, attrName);
    if (it == keys.end()) {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
      return {};
    }
    unsigned keyWordIndex = it - keys.begin();
    // Consume the `=` after the key.
    if (failed(parser.parseEqual()))
      return {};
    // Dispatch on the keyword.
    switch (keyWordIndex) {
    case 0: { // map
      ir_detail::DimLvlMapParser cParser(parser);
      auto res = cParser.parseDimLvlMap();
      if (failed(res))
        return {};
      const auto &dlm = *res;

      const Level lvlRank = dlm.getLvlRank();
      for (Level lvl = 0; lvl < lvlRank; lvl++)
        lvlTypes.push_back(dlm.getLvlType(lvl));

      const Dimension dimRank = dlm.getDimRank();
      for (Dimension dim = 0; dim < dimRank; dim++)
        dimSlices.push_back(dlm.getDimSlice(dim));
      // NOTE: the old syntax requires an all-or-nothing approach to
      // `dimSlices`; therefore, if any slice actually exists then we need
      // to convert null-DSA into default/nop DSA.
      const auto isDefined = [](SparseTensorDimSliceAttr slice) {
        return static_cast<bool>(slice.getImpl());
      };
      if (llvm::any_of(dimSlices, isDefined)) {
        const auto defaultSlice =
            SparseTensorDimSliceAttr::get(parser.getContext());
        for (Dimension dim = 0; dim < dimRank; dim++)
          if (!isDefined(dimSlices[dim]))
            dimSlices[dim] = defaultSlice;
      } else {
        dimSlices.clear();
      }

      dimToLvl = dlm.getDimToLvlMap(parser.getContext());
      lvlToDim = dlm.getLvlToDimMap(parser.getContext());
      break;
    }
    case 1: { // posWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral position bitwidth");
        return {};
      }
      posWidth = intAttr.getInt();
      break;
    }
    case 2: { // crdWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      crdWidth = intAttr.getInt();
      break;
    }
    } // switch
    // Only the last item can omit the comma.
    if (parser.parseOptionalComma().failed())
      break;
  }

  // Close "}>" part.
  if (failed(parser.parseRBrace()))
    return {};
  if (failed(parser.parseGreater()))
    return {};

  // Construct struct-like storage for attribute.
  if (!lvlToDim || lvlToDim.isEmpty()) {
    lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
  }
  return parser.getChecked<SparseTensorEncodingAttr>(
      parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
      dimSlices);
}

void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
  auto map = static_cast<AffineMap>(getDimToLvl());
  // An empty affine map indicates the identity map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
  printer << "<{ map = ";
  printSymbols(map, printer);
  printer << '(';
  printDimensions(map, printer, getDimSlices());
  printer << ") -> (";
  printLevels(map, printer, getLvlTypes());
  printer << ')';
  // Print remaining members only for non-default values.
  if (getPosWidth())
    printer << ", posWidth = " << getPosWidth();
  if (getCrdWidth())
    printer << ", crdWidth = " << getCrdWidth();
  printer << " }>";
}

void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
                                            AsmPrinter &printer) const {
  if (map.getNumSymbols() == 0)
    return;
  printer << '[';
  for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
    printer << 's' << i << ", ";
  if (map.getNumSymbols() >= 1)
    printer << 's' << map.getNumSymbols() - 1;
  printer << ']';
}

void SparseTensorEncodingAttr::printDimensions(
    AffineMap &map, AsmPrinter &printer,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  if (!dimSlices.empty()) {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << " : " << dimSlices[i] << ", ";
    if (map.getNumDims() >= 1) {
      printer << 'd' << map.getNumDims() - 1 << " : "
              << dimSlices[map.getNumDims() - 1];
    }
  } else {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << ", ";
    if (map.getNumDims() >= 1)
      printer << 'd' << map.getNumDims() - 1;
  }
}

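// For an n:m level type such as 2:4 structured sparsity, getNOutOfMString
// below yields "[2, 4]", so the level prints as, e.g., "structured[2, 4]".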
std::string getNOutOfMString(LevelType lt) {
  if (isNOutOfMLT(lt)) {
    unsigned n = getN(lt);
    unsigned m = getM(lt);
    auto output = "[" + std::to_string(n) + ", " + std::to_string(m) + "]";
    return output;
  }
  return "";
}

void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
                                           ArrayRef<LevelType> lvlTypes) const {
  for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
    map.getResult(i).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[i])
            << getNOutOfMString(lvlTypes[i]) << ", ";
  }
  if (map.getNumResults() >= 1) {
    auto lastIndex = map.getNumResults() - 1;
    map.getResult(lastIndex).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[lastIndex])
            << getNOutOfMString(lvlTypes[lastIndex]);
  }
}

LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
    AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
    unsigned crdWidth, ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
  if (!acceptBitWidth(posWidth))
    return emitError() << "unexpected position bitwidth: " << posWidth;
  if (!acceptBitWidth(crdWidth))
    return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
      it != std::end(lvlTypes)) {
    if (it == lvlTypes.begin() ||
        (!isCompressedLT(*(it - 1)) && !isLooseCompressedLT(*(it - 1))))
      return emitError() << "expected compressed or loose_compressed level "
                            "before singleton level";
    if (!std::all_of(it, lvlTypes.end(),
                     [](LevelType i) { return isSingletonLT(i); }))
      return emitError() << "expected all singleton lvlTypes "
                            "following a singleton level";
    // We can potentially support mixed SoA/AoS singleton levels.
    if (!std::all_of(it, lvlTypes.end(), [it](LevelType i) {
          return it->isa<LevelPropNonDefault::SoA>() ==
                 i.isa<LevelPropNonDefault::SoA>();
        })) {
      return emitError() << "expected all singleton lvlTypes stored in the "
                            "same memory layout (SoA vs AoS).";
    }
  }

  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
    return emitError() << "Batch lvlType can only be leading levels.";

  // The SoA property can only be applied to singleton levels.
  auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
    return lt.isa<LevelPropNonDefault::SoA>();
  });
  if (llvm::any_of(soaLvls, [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      })) {
    return emitError() << "SoA is only applicable to singleton lvlTypes.";
  }

  // TODO: audit formats that actually are supported by backend.
  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
      it != std::end(lvlTypes)) {
    if (it != lvlTypes.end() - 1)
      return emitError() << "expected n_out_of_m to be the last level type";
    if (!std::all_of(lvlTypes.begin(), it,
                     [](LevelType i) { return isDenseLT(i); }))
      return emitError() << "expected all dense lvlTypes "
                            "before a n_out_of_m level";
    if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
      if (!isBlockSparsity(dimToLvl)) {
        return emitError()
               << "expected 1xm block structure for n_out_of_m level";
      }
      auto sizes = getBlockSize(dimToLvl);
      unsigned coefficient = 0;
      for (const auto &elem : sizes) {
        if (elem != 0) {
          if (elem != coefficient && coefficient != 0) {
            return emitError() << "expected only one blocked level "
                                  "with the same coefficients";
          }
          coefficient = elem;
        }
      }
      if (coefficient != getM(*it)) {
        return emitError() << "expected coefficients of Affine expressions "
                              "to be equal to m of n_out_of_m level";
      }
    }
  }
  // Before we can check that the level-rank is consistent/coherent
  // across all fields, we need to define it. The source-of-truth for
  // the `getLvlRank` method is the length of the level-types array,
  // since it must always be provided and have full rank; therefore we
  // use that same source-of-truth here.
  const Level lvlRank = lvlTypes.size();
  if (lvlRank == 0)
    return emitError() << "expected a non-empty array for lvlTypes";
  // We save `dimRank` here because we'll also need it to verify `dimSlices`.
  const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
  if (dimToLvl) {
    if (dimToLvl.getNumResults() != lvlRank)
      return emitError()
             << "level-rank mismatch between dimToLvl and lvlTypes: "
             << dimToLvl.getNumResults() << " != " << lvlRank;
    auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
    // Symbols can't be inferred but are acceptable.
    if (!inferRes && dimToLvl.getNumSymbols() == 0)
      return emitError() << "failed to infer lvlToDim from dimToLvl";
    if (lvlToDim && (inferRes != lvlToDim))
      return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
    if (dimRank > lvlRank)
      return emitError() << "unexpected dimToLvl mapping from " << dimRank
                         << " to " << lvlRank;
  }
  if (!dimSlices.empty()) {
    if (dimSlices.size() != dimRank)
      return emitError()
             << "dimension-rank mismatch between dimSlices and dimToLvl: "
             << dimSlices.size() << " != " << dimRank;
    // Compiler support for `dimSlices` currently requires that the two
    // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
    if (dimRank != lvlRank)
      return emitError()
             << "dimSlices expected dimension-rank to match level-rank: "
             << dimRank << " != " << lvlRank;
  }
  return success();
}

LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<Size> dimShape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity. In particular, this ensures that the
  // level-rank is coherent across all the fields.
  if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
                    getPosWidth(), getCrdWidth(), getDimSlices())))
    return failure();
  // Check integrity with tensor type specifics. In particular, we
  // need only check that the dimension-rank of the tensor agrees with
  // the dimension-rank of the encoding.
  const Dimension dimRank = dimShape.size();
  if (dimRank == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimRank() != dimRank)
    return emitError()
           << "dimension-rank mismatch between encoding and tensor shape: "
           << getDimRank() << " != " << dimRank;
  return success();
}

//===----------------------------------------------------------------------===//
// SparseTensorType Methods.
//===----------------------------------------------------------------------===//

bool SparseTensorType::isCOOType(Level startLvl, bool isUnique) const {
  if (!hasEncoding())
    return false;
  if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
    return false;
  for (Level l = startLvl + 1; l < lvlRank; ++l)
    if (!isSingletonLvl(l))
      return false;
  // If isUnique is true, then make sure that the last level is unique,
  // that is, when lvlRank == 1, the only compressed level is unique,
  // and when lvlRank > 1, the last singleton is unique.
  return !isUnique || isUniqueLvl(lvlRank - 1);
}

Level SparseTensorType::getAoSCOOStart() const {
  SmallVector<COOSegment> coo = getCOOSegments();
  assert(coo.size() == 1 || coo.empty());
  if (!coo.empty() && coo.front().isAoS()) {
    return coo.front().lvlRange.first;
  }
  return lvlRank;
}

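// Illustrative example: for lvlTypes = [dense, compressed(nonunique),
// singleton], getCOOSegments below returns a single segment covering the
// level range [1, 3), marked as SoA or AoS according to the singleton level.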
SmallVector<COOSegment> SparseTensorType::getCOOSegments() const {
  SmallVector<COOSegment> ret;
  if (!hasEncoding() || lvlRank <= 1)
    return ret;

  ArrayRef<LevelType> lts = getLvlTypes();
  Level l = 0;
  while (l < lvlRank) {
    auto lt = lts[l];
    if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
      auto cur = lts.begin() + l;
      auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      });
      unsigned cooLen = std::distance(cur, end);
      if (cooLen > 1) {
        // To support mixed SoA/AoS COO, we should break the segment when the
        // storage scheme changes; for now we faithfully assume that all
        // consecutive singleton levels have the same storage format, as
        // verified by the encoding verifier.
        ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
                                 lts[l + 1].isa<LevelPropNonDefault::SoA>()});
      }
      l += cooLen;
    } else {
      l++;
    }
  }
  return ret;
}

RankedTensorType
SparseTensorType::getCOOType(bool ordered) const {
  SmallVector<LevelType> lvlTypes;
  lvlTypes.reserve(lvlRank);
  // A non-unique compressed level at the beginning (unless this is
  // also the last level, in which case it is unique).
  lvlTypes.push_back(
      *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
  if (lvlRank > 1) {
    // Followed by n-2 non-unique singleton levels.
    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
                *buildLevelType(LevelFormat::Singleton, ordered, false));
    // Ends with a unique singleton level.
    lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
  }
  auto enc = SparseTensorEncodingAttr::get(getContext(), lvlTypes,
                                           getDimToLvl(), getLvlToDim(),
                                           getPosWidth(), getCrdWidth());
  return RankedTensorType::get(getDimShape(), getElementType(), enc);
}

//===----------------------------------------------------------------------===//
// Convenience Methods.
//===----------------------------------------------------------------------===//

SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
  if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
    return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
  if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
    return mdtp.getEncoding();
  return nullptr;
}

AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
                                             MLIRContext *context) {
  auto map = static_cast<AffineMap>(dimToLvl);
  AffineMap lvlToDim;
  // Return an empty lvlToDim when inference is not successful.
  if (!map || map.getNumSymbols() != 0) {
    lvlToDim = AffineMap();
  } else if (map.isPermutation()) {
    lvlToDim = inversePermutation(map);
  } else if (isBlockSparsity(map)) {
    lvlToDim = inverseBlockSparsity(map, context);
  }
  return lvlToDim;
}

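// Illustrative example of the inversion performed below: the BSR map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// is inverted to
//   (d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 3 + d3).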
AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
                                                    MLIRContext *context) {
  SmallVector<AffineExpr> lvlExprs;
  auto numLvls = dimToLvl.getNumResults();
  lvlExprs.reserve(numLvls);
  // lvlExprComponents stores information of the floordiv and mod operations
  // applied to the same dimension, so as to build the lvlToDim map.
  std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
  for (unsigned i = 0, n = numLvls; i < n; i++) {
    auto result = dimToLvl.getResult(i);
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::FloorDiv) {
        // Position of the dimension in dimToLvl.
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
               "expected only one floordiv for each dimension");
        SmallVector<AffineExpr, 3> components;
        // Level variable for floordiv.
        components.push_back(getAffineDimExpr(i, context));
        // Multiplier.
        components.push_back(binOp.getRHS());
        // Map key is the position of the dimension.
        lvlExprComponents[pos] = components;
      } else if (result.getKind() == AffineExprKind::Mod) {
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
               "expected floordiv before mod");
        // Add the level variable for the mod to the same vector
        // as the corresponding floordiv.
        lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
      } else {
        assert(false && "expected floordiv or mod");
      }
    } else {
      lvlExprs.push_back(getAffineDimExpr(i, context));
    }
  }
  // Build lvlExprs from lvlExprComponents.
  // For example, for il = i floordiv 2 and ii = i mod 2, the components
  // would be [il, 2, ii]. They can be used to build the AffineExpr
  // i = il * 2 + ii in lvlToDim.
  for (auto &components : lvlExprComponents) {
    assert(components.second.size() == 3 &&
           "expected 3 components to build lvlExprs");
    auto mulOp = getAffineBinaryOpExpr(
        AffineExprKind::Mul, components.second[0], components.second[1]);
    auto addOp =
        getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
    lvlExprs.push_back(addOp);
  }
  return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
}

SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
  assert(isBlockSparsity(dimToLvl) &&
         "expected dimToLvl to be block sparsity for calling getBlockSize");
  SmallVector<unsigned> blockSize;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::Mod) {
        blockSize.push_back(
            dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
      }
    } else {
      blockSize.push_back(0);
    }
  }
  return blockSize;
}

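// Illustrative examples for the check below:
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// is accepted as block sparsity, while
//   (d0, d1) -> (d0 mod 2, d1)
// is rejected, since the mod has no matching floordiv on the same dimension.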
bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
  if (!dimToLvl)
    return false;
  std::map<unsigned, int64_t> coefficientMap;
  bool hasBlock = false;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      // Check for "dim op const".
      auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
      auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
      if (!dimOp || !conOp || conOp.getValue() <= 0)
        return false;
      // Inspect "dim / const" or "dim % const".
      auto pos = dimOp.getPosition();
      if (binOp.getKind() == AffineExprKind::FloorDiv) {
        // Expect only one floordiv for each dimension.
        if (coefficientMap.find(pos) != coefficientMap.end())
          return false;
        // Record the coefficient of the floordiv.
        coefficientMap[pos] = conOp.getValue();
      } else if (binOp.getKind() == AffineExprKind::Mod) {
        // Expect a floordiv before the mod.
        if (coefficientMap.find(pos) == coefficientMap.end())
          return false;
        // Expect the mod to have the same coefficient as the floordiv.
        if (conOp.getValue() != coefficientMap[pos])
          return false;
        hasBlock = true;
      } else {
        return false;
      }
    } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
      auto pos = dimOp.getPosition();
      // Expect the dimension to be unset.
      if (coefficientMap.find(pos) != coefficientMap.end())
        return false;
      coefficientMap[pos] = 0;
    } else {
      return false;
    }
  }
  return hasBlock;
}

bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
  auto hasNonIdentityMap = [](Value v) {
    auto stt = tryGetSparseTensorType(v);
    return stt && !stt->isIdentity();
  };

  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
         llvm::any_of(op->getResults(), hasNonIdentityMap);
}

Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
  if (enc) {
    assert(enc.isPermutation() && "Non-permutation map not supported");
    if (const auto dimToLvl = enc.getDimToLvl())
      return dimToLvl.getDimPosition(l);
  }
  return l;
}

Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
  if (enc) {
    assert(enc.isPermutation() && "Non-permutation map not supported");
    if (const auto lvlToDim = enc.getLvlToDim())
      return lvlToDim.getDimPosition(d);
  }
  return d;
}

/// We normalize the sparse tensor encoding attribute by always using
/// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
/// as other variants) lead to the same storage specifier type, and by
/// stripping irrelevant fields that do not alter the sparse tensor memory
/// layout.
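/// For example (illustrative): "compressed(nonunique, nonordered)",
/// "compressed(nonunique)", and plain "compressed" all normalize to the same
/// level type here, so their storage specifiers share a single type.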
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
  SmallVector<LevelType> lts;
  for (auto lt : enc.getLvlTypes())
    lts.push_back(lt.stripStorageIrrelevantProperties());

  return SparseTensorEncodingAttr::get(
      enc.getContext(), lts,
      AffineMap(), // dimToLvl (irrelevant to storage specifier)
      AffineMap(), // lvlToDim (irrelevant to storage specifier)
      // Always use `index` for memSize and lvlSize instead of reusing
      // `getPosWidth` and `getCrdWidth`. This allows us to reuse the same SSA
      // value for different bitwidths; it also avoids casting between index
      // and integer (returned by DimOp).
      0, 0, enc.getDimSlices());
}

StorageSpecifierType
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
  return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
//===----------------------------------------------------------------------===//

static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
  return success(lvl < getSparseTensorType(tensor).getLvlRank());
}

static LogicalResult isMatchingWidth(Value mem, unsigned width) {
  const Type etp = getMemRefType(mem).getElementType();
  return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
}

static LogicalResult verifySparsifierGetterSetter(
    StorageSpecifierKind mdKind, std::optional<Level> lvl,
    TypedValue<StorageSpecifierType> md, Operation *op) {
  if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
    return op->emitError(
        "redundant level argument for querying value memory size");
  }

  const auto enc = md.getType().getEncoding();
  const Level lvlRank = enc.getLvlRank();

  if (mdKind == StorageSpecifierKind::DimOffset ||
      mdKind == StorageSpecifierKind::DimStride)
    if (!enc.isSlice())
      return op->emitError("requested slice data on non-slice tensor");

  if (mdKind != StorageSpecifierKind::ValMemSize) {
    if (!lvl)
      return op->emitError("missing level argument");

    const Level l = lvl.value();
    if (l >= lvlRank)
      return op->emitError("requested level is out of bounds");

    if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
      return op->emitError(
          "requested position memory size on a singleton level");
  }
  return success();
}

static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
  switch (kind) {
  case SparseTensorFieldKind::CrdMemRef:
    return stt.getCrdType();
  case SparseTensorFieldKind::PosMemRef:
    return stt.getPosType();
  case SparseTensorFieldKind::ValMemRef:
    return stt.getElementType();
  case SparseTensorFieldKind::StorageSpec:
    return nullptr;
  }
  llvm_unreachable("Unrecognizable FieldKind");
}

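// Illustrative example for the COO check below: for a tensor with lvlTypes =
// [compressed(nonunique), singleton], the trailing COO region spans two
// levels, so the last level input must be a rank-2 tensor of shape <? x 2>
// holding the fused coordinates.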
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                      SparseTensorType stt,
                                      RankedTensorType valTp,
                                      TypeRange lvlTps) {
  if (requiresStaticShape && !stt.hasStaticDimShape())
    return op->emitError("the sparse-tensor must have static shape");
  if (!stt.hasEncoding())
    return op->emitError("the sparse-tensor must have an encoding attribute");

  // Verifies the trailing COO.
  Level cooStartLvl = stt.getAoSCOOStart();
  if (cooStartLvl < stt.getLvlRank()) {
    // We only support trailing COO for now; it must be the last input.
    auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
    // The coordinates should be in the shape of <? x rank>.
    unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
    if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
      return op->emitError("input/output trailing COO level-ranks don't match");
    }
  }

  // Verifies that all types match.
  StorageLayout layout(stt.getEncoding());
  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
    return op->emitError("inconsistent number of fields between input/output");

  unsigned idx = 0;
  bool misMatch = false;
  layout.foreachField([&idx, &misMatch, stt, valTp,
                       lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
                               Level lvl, LevelType lt) -> bool {
    if (fKind == SparseTensorFieldKind::StorageSpec)
      return true;

    Type inputTp = nullptr;
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      inputTp = valTp;
    } else {
      assert(fid == idx && stt.getLvlType(lvl) == lt);
      inputTp = lvlTps[idx++];
    }
    // The input element type and expected element type should match.
    Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
    Type expElemTp = getFieldElemType(stt, fKind);
    if (inpElemTp != expElemTp) {
      misMatch = true;
      return false; // to terminate the iteration
    }
    return true;
  });

  if (misMatch)
    return op->emitError("input/output element-types don't match");
  return success();
}

LogicalResult AssembleOp::verify() {
  const auto valuesTp = getRankedTensorType(getValues());
  const auto lvlsTp = getLevels().getTypes();
  const auto resTp = getSparseTensorType(getResult());
  return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
}

LogicalResult DisassembleOp::verify() {
  if (getOutValues().getType() != getRetValues().getType())
    return emitError("output values and return value type mismatch");

  for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
    if (ot.getType() != rt.getType())
      return emitError("output levels and return levels type mismatch");

  const auto valuesTp = getRankedTensorType(getRetValues());
  const auto lvlsTp = getRetLevels().getTypes();
  const auto srcTp = getSparseTensorType(getTensor());
  return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
}

LogicalResult ConvertOp::verify() {
  if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
    if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
      if (tp1.getRank() != tp2.getRank())
        return emitError("unexpected conversion mismatch in rank");
      auto dstEnc =
          llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
      if (dstEnc && dstEnc.isSlice())
        return emitError("cannot convert to a sparse tensor slice");

      auto shape1 = tp1.getShape();
      auto shape2 = tp2.getShape();
      // Accept size matches between the source and the destination type
      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches
      // or matches that would need a runtime assert (e.g. 10 vs. 20 or
      // ? vs. 10).
      for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
          return emitError("unexpected conversion mismatch in dimension ") << d;
      return success();
    }
  }
  return emitError("unexpected type in convert");
}

OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
  if (getType() == getSource().getType())
    return getSource();
  return {};
}

bool ConvertOp::needsExtraSort() {
  SparseTensorType srcStt = getSparseTensorType(getSource());
  SparseTensorType dstStt = getSparseTensorType(getDest());

  // We do not need an extra sort when returning unordered sparse tensors or
  // dense tensors, since dense tensors support random access.
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
      srcStt.hasSameDimToLvl(dstStt)) {
    return false;
  }

  // The source and dest tensors are ordered in different ways. We only do a
  // direct dense to sparse conversion when the dense input is defined by a
  // sparse constant. Note that we can theoretically always directly convert
  // from dense inputs by rotating dense loops, but it leads to bad cache
  // locality and hurts performance.
  if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
    if (isa<SparseElementsAttr>(constOp.getValue()))
      return false;

  return true;
}

LogicalResult CrdTranslateOp::verify() {
  uint64_t inRank = getEncoder().getLvlRank();
  uint64_t outRank = getEncoder().getDimRank();

  if (getDirection() == CrdTransDirectionKind::dim2lvl)
    std::swap(inRank, outRank);

  if (inRank != getInCrds().size() || outRank != getOutCrds().size())
    return emitError("Coordinate rank mismatch with encoding");

  return success();
}

LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                   SmallVectorImpl<OpFoldResult> &results) {
  if (getEncoder().isIdentity()) {
    results.assign(getInCrds().begin(), getInCrds().end());
    return success();
  }
  if (getEncoder().isPermutation()) {
    AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
                         ? getEncoder().getDimToLvl()
                         : getEncoder().getLvlToDim();
    for (AffineExpr exp : perm.getResults())
      results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
    return success();
  }

  // Fuse dim2lvl/lvl2dim pairs.
  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
  bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
    return v.getDefiningOp() == def;
  });
  if (!sameDef)
    return failure();

  bool oppositeDir = def.getDirection() != getDirection();
  bool sameOracle =
      def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
  bool sameCount = def.getNumResults() == getInCrds().size();
  if (!oppositeDir || !sameOracle || !sameCount)
    return failure();

  // The definition produces the coordinates in the same order as the input
  // coordinates.
  bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
                                [](auto valuePair) {
                                  auto [lhs, rhs] = valuePair;
                                  return lhs == rhs;
                                });

  if (!sameOrder)
    return failure();
  // l1 = dim2lvl (lvl2dim l0)
  // ==> l0
  results.append(def.getInCrds().begin(), def.getInCrds().end());
  return success();
}

void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
                  int64_t index) {
  Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
  return build(builder, state, source, val);
}

LogicalResult LvlOp::verify() {
  if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
    auto stt = getSparseTensorType(getSource());
    if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
      return emitError(
          "Level index exceeds the rank of the input sparse tensor");
  }
  return success();
}

std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability LvlOp::getSpeculatability() {
  auto constantIndex = getConstantLvlIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  assert(constantIndex <
         cast<RankedTensorType>(getSource().getType()).getRank());
  return Speculation::Speculatable;
}

OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
  auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!lvlIndex)
    return {};

  Level lvl = lvlIndex.getAPSInt().getZExtValue();
  auto stt = getSparseTensorType(getSource());
  if (lvl >= stt.getLvlRank()) {
    // Follows the same convention used by the tensor.dim operation. Out of
    // bound indices produce undefined behavior but are still valid IR. Don't
    // choke on them.
    return {};
  }

  // Helper lambda to build an IndexAttr.
  auto getIndexAttr = [this](int64_t lvlSz) {
    return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
  };

  SmallVector<Size> lvlShape = stt.getLvlShape();
  if (!ShapedType::isDynamic(lvlShape[lvl]))
    return getIndexAttr(lvlShape[lvl]);

  return {};
}

void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                             SparseTensorEncodingAttr dstEnc, Value source) {
  auto srcStt = getSparseTensorType(source);
  SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
  SmallVector<int64_t> dstDimShape =
      dstEnc.tranlateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
  auto dstTp =
      RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
  return build(odsBuilder, odsState, dstTp, source);
}

LogicalResult ReinterpretMapOp::verify() {
  auto srcStt = getSparseTensorType(getSource());
  auto dstStt = getSparseTensorType(getDest());
  ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
  ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();

  if (srcLvlTps.size() != dstLvlTps.size())
    return emitError("Level rank mismatch between source/dest tensors");

  for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
    if (srcLvlTp != dstLvlTp)
      return emitError("Level type mismatch between source/dest tensors");

  if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
      srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
    return emitError("Crd/Pos width mismatch between source/dest tensors");
  }

  if (srcStt.getElementType() != dstStt.getElementType())
    return emitError("Element type mismatch between source/dest tensors");

  SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
  SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
  for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
    if (srcLvlSz != dstLvlSz) {
      // Should we allow one side to be a dynamic size, e.g., should <?x?> be
      // compatible with <3x4>? For now, we require all the level sizes to be
      // *exactly* matched for simplicity.
      return emitError("Level size mismatch between source/dest tensors");
    }
  }

  return success();
}

OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
  if (getSource().getType() == getDest().getType())
    return getSource();

  if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
    // A -> B, B -> A ==> A
    if (def.getSource().getType() == getDest().getType())
      return def.getSource();
  }
  return {};
}

LogicalResult ToPositionsOp::verify() {
  auto stt = getSparseTensorType(getTensor());
  if (failed(lvlIsInBounds(getLevel(), getTensor())))
    return emitError("requested level is out of bounds");
  if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
    return emitError("unexpected type for positions");
  return success();
}

LogicalResult ToCoordinatesOp::verify() {
  auto stt = getSparseTensorType(getTensor());
  if (failed(lvlIsInBounds(getLevel(), getTensor())))
    return emitError("requested level is out of bounds");
  if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
    return emitError("unexpected type for coordinates");
  return success();
}

LogicalResult ToCoordinatesBufferOp::verify() {
  auto stt = getSparseTensorType(getTensor());
  if (stt.getAoSCOOStart() >= stt.getLvlRank())
    return emitError("expected sparse tensor with a COO region");
  return success();
}

LogicalResult ToValuesOp::verify() {
  auto stt = getSparseTensorType(getTensor());
  auto mtp = getMemRefType(getResult());
  if (stt.getElementType() != mtp.getElementType())
    return emitError("unexpected mismatch in element types");
  return success();
}

LogicalResult ToSliceOffsetOp::verify() {
  auto rank = getRankedTensorType(getSlice()).getRank();
  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
    return emitError("requested dimension out of bound");
  return success();
}

LogicalResult ToSliceStrideOp::verify() {
  auto rank = getRankedTensorType(getSlice()).getRank();
  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
    return emitError("requested dimension out of bound");
  return success();
}

LogicalResult GetStorageSpecifierOp::verify() {
  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
                                      getSpecifier(), getOperation());
}

template <typename SpecifierOp>
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
  return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
}

OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
  const StorageSpecifierKind kind = getSpecifierKind();
  const auto lvl = getLevel();
  for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
    if (kind == op.getSpecifierKind() && lvl == op.getLevel())
      return op.getValue();
  return {};
}

LogicalResult SetStorageSpecifierOp::verify() {
  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
                                      getSpecifier(), getOperation());
}

template <class T>
static LogicalResult verifyNumBlockArgs(T *op, Region &region,
                                        const char *regionName,
                                        TypeRange inputTypes, Type outputType) {
  unsigned numArgs = region.getNumArguments();
  unsigned expectedNum = inputTypes.size();
  if (numArgs != expectedNum)
    return op->emitError() << regionName << " region must have exactly "
                           << expectedNum << " arguments";

  for (unsigned i = 0; i < numArgs; i++) {
    Type typ = region.getArgument(i).getType();
    if (typ != inputTypes[i])
      return op->emitError() << regionName << " region argument " << (i + 1)
                             << " type mismatch";
  }
  Operation *term = region.front().getTerminator();
  YieldOp yield = dyn_cast<YieldOp>(term);
  if (!yield)
    return op->emitError() << regionName
                           << " region must end with sparse_tensor.yield";
  if (!yield.getResult() || yield.getResult().getType() != outputType)
    return op->emitError() << regionName << " region yield type mismatch";

  return success();
}

LogicalResult BinaryOp::verify() {
  NamedAttrList attrs = (*this)->getAttrs();
  Type leftType = getX().getType();
  Type rightType = getY().getType();
  Type outputType = getOutput().getType();
  Region &overlap = getOverlapRegion();
  Region &left = getLeftRegion();
  Region &right = getRightRegion();

  // Check the correct number of block arguments and return type for each
  // non-empty region.
  if (!overlap.empty()) {
    if (failed(verifyNumBlockArgs(this, overlap, "overlap",
                                  TypeRange{leftType, rightType}, outputType)))
      return failure();
  }
  if (!left.empty()) {
    if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
                                  outputType)))
      return failure();
  } else if (getLeftIdentity()) {
    if (leftType != outputType)
      return emitError("left=identity requires first argument to have the same "
                       "type as the output");
  }
  if (!right.empty()) {
    if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
                                  outputType)))
      return failure();
  } else if (getRightIdentity()) {
    if (rightType != outputType)
      return emitError("right=identity requires second argument to have the "
                       "same type as the output");
  }
  return success();
}

LogicalResult UnaryOp::verify() {
  Type inputType = getX().getType();
  Type outputType = getOutput().getType();

  // Check the correct number of block arguments and return type for each
  // non-empty region.
  Region &present = getPresentRegion();
  if (!present.empty()) {
    if (failed(verifyNumBlockArgs(this, present, "present",
                                  TypeRange{inputType}, outputType)))
      return failure();
  }
  Region &absent = getAbsentRegion();
  if (!absent.empty()) {
    if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
                                  outputType)))
      return failure();
    // The absent branch can only yield invariant values.
    Block *absentBlock = &absent.front();
    Block *parent = getOperation()->getBlock();
    Value absentVal = cast<YieldOp>(absentBlock->getTerminator()).getResult();
    if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
      if (arg.getOwner() == parent)
        return emitError("absent region cannot yield linalg argument");
    } else if (Operation *def = absentVal.getDefiningOp()) {
      if (!isa<arith::ConstantOp>(def) &&
          (def->getBlock() == absentBlock || def->getBlock() == parent))
        return emitError("absent region cannot yield locally computed value");
    }
  }
  return success();
}

1607 bool ConcatenateOp::needsExtraSort() {
1608  SparseTensorType dstStt = getSparseTensorType(*this);
1609  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1610  return false;
1611 
1612  bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1613  return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1614  });
1615  // TODO: When conDim != 0, as long as conDim corresponds to the first level
1616  // in all input/output buffers, and all input/output buffers have the same
1617  // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenating
1618  // CSC matrices along the column dimension).
1619  bool directLowerable =
1620  allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1621  return !directLowerable;
1622 }
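// For example, concatenating all-ordered CSR matrices along dimension 0 with
// identity dimToLvl appends rows in order, so no intermediate COO sort is
// needed; concatenating the same matrices along dimension 1 interleaves
// entries and requires the extra sort.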
1623 
1624 LogicalResult ConcatenateOp::verify() {
1625  const auto dstTp = getSparseTensorType(*this);
1626  const Dimension concatDim = getDimension();
1627  const Dimension dimRank = dstTp.getDimRank();
1628 
1629  if (getInputs().size() <= 1)
1630  return emitError("Need at least two tensors to concatenate.");
1631 
1632  if (concatDim >= dimRank)
1633  return emitError(llvm::formatv(
1634  "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1635  concatDim, dimRank));
1636 
1637  for (const auto &it : llvm::enumerate(getInputs())) {
1638  const auto i = it.index();
1639  const auto srcTp = getSparseTensorType(it.value());
1640  if (srcTp.hasDynamicDimShape())
1641  return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1642  const Dimension srcDimRank = srcTp.getDimRank();
1643  if (srcDimRank != dimRank)
1644  return emitError(
1645  llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1646  "from the output tensor (rank={2}).",
1647  i, srcDimRank, dimRank));
1648  }
1649 
1650  for (Dimension d = 0; d < dimRank; d++) {
1651  const Size dstSh = dstTp.getDimShape()[d];
1652  if (d == concatDim) {
1653  if (!ShapedType::isDynamic(dstSh)) {
1654  // If we reach here, then all inputs have static shapes. So we
1655  // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1656  // to avoid redundant assertions in the loop.
1657  Size sumSz = 0;
1658  for (const auto src : getInputs())
1659  sumSz += getSparseTensorType(src).getDimShape()[d];
1660  // If all dimensions are statically known, the sum of all the input
1661  // dimensions should equal the output dimension.
1662  if (sumSz != dstSh)
1663  return emitError(
1664  "The concatenation dimension of the output tensor should be the "
1665  "sum of all the concatenation dimensions of the input tensors.");
1666  }
1667  } else {
1668  Size prev = dstSh;
1669  for (const auto src : getInputs()) {
1670  const auto sh = getSparseTensorType(src).getDimShape()[d];
1671  if (!ShapedType::isDynamic(prev) && sh != prev)
1672  return emitError("All dimensions (expect for the concatenating one) "
1673  "should be equal.");
1674  prev = sh;
1675  }
1676  }
1677  }
1678 
1679  return success();
1680 }
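// An illustrative concatenation that passes these checks: static shapes,
// matching ranks, and input sizes along the concat dimension summing to the
// output size (the #CSR alias is an assumption):
//
//   %0 = sparse_tensor.concatenate %t0, %t1 {dimension = 0 : index}
//        : tensor<3x4xf64, #CSR>, tensor<5x4xf64, #CSR>
//          to tensor<8x4xf64, #CSR>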
1681 
1682 LogicalResult InsertOp::verify() {
1683  const auto stt = getSparseTensorType(getTensor());
1684  if (stt.getLvlRank() != static_cast<Level>(getLvlCoords().size()))
1685  return emitOpError("incorrect number of coordinates");
1686  return success();
1687 }
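// An illustrative insertion; the verifier above requires exactly lvlRank
// coordinates (the 2-D #CSR tensor is an assumption):
//
//   %r = sparse_tensor.insert %val into %t[%i, %j] : tensor<4x4xf64, #CSR>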
1688 
1689 void PushBackOp::build(OpBuilder &builder, OperationState &result,
1690  Value curSize, Value inBuffer, Value value) {
1691  build(builder, result, curSize, inBuffer, value, Value());
1692 }
1693 
1694 LogicalResult PushBackOp::verify() {
1695  if (Value n = getN()) {
1696  std::optional<int64_t> nValue = getConstantIntValue(n);
1697  if (nValue && nValue.value() < 1)
1698  return emitOpError("n must be at least 1");
1699  }
1700  return success();
1701 }
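// An illustrative push_back; the op appends %val and returns the (possibly
// reallocated) buffer together with the new size (value names are
// assumptions):
//
//   %buf, %newSize = sparse_tensor.push_back %curSize, %inBuffer, %val
//        : index, memref<?xf64>, f64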
1702 
1703 LogicalResult CompressOp::verify() {
1704  const auto stt = getSparseTensorType(getTensor());
1705  if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1706  return emitOpError("incorrect number of coordinates");
1707  return success();
1708 }
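// An illustrative compress, writing the expanded-access-pattern buffers back
// into row %i of a 2-D tensor, so only lvlRank - 1 = 1 coordinate is given
// (types and value names are assumptions):
//
//   %r = sparse_tensor.compress %values, %filled, %added, %count into %t[%i]
//        : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<4x4xf64, #CSR>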
1709 
1710 void ForeachOp::build(
1711  OpBuilder &builder, OperationState &result, Value tensor,
1712  ValueRange initArgs, AffineMapAttr order,
1713  function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
1714  bodyBuilder) {
1715  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
1716  // Builds foreach body.
1717  if (!bodyBuilder)
1718  return;
1719  const auto stt = getSparseTensorType(tensor);
1720  const Dimension dimRank = stt.getDimRank();
1721 
1722  // Starts with `dimRank`-many coordinates.
1723  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
1724  // Followed by one value.
1725  blockArgTypes.push_back(stt.getElementType());
1726  // Followed by the reduction variables.
1727  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
1728 
1729  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());
1730 
1731  OpBuilder::InsertionGuard guard(builder);
1732  auto &region = *result.regions.front();
1733  Block *bodyBlock =
1734  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1735  bodyBuilder(builder, result.location,
1736  bodyBlock->getArguments().slice(0, dimRank),
1737  bodyBlock->getArguments()[dimRank],
1738  bodyBlock->getArguments().drop_front(dimRank + 1));
1739 }
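// A minimal C++ usage sketch for this builder, summing all stored values
// (the names b, loc, tensor, and init are assumptions):
//
//   auto foreach = b.create<ForeachOp>(
//       loc, tensor, ValueRange{init}, AffineMapAttr(),
//       [](OpBuilder &b, Location loc, ValueRange coords, Value v,
//          ValueRange reduc) {
//         Value sum = b.create<arith::AddFOp>(loc, reduc.front(), v);
//         b.create<YieldOp>(loc, sum);
//       });
//   Value result = foreach.getResult(0);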
1740 
1741 LogicalResult ForeachOp::verify() {
1742  const auto t = getSparseTensorType(getTensor());
1743  const Dimension dimRank = t.getDimRank();
1744  const auto args = getBody()->getArguments();
1745 
1746  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1747  return emitError("Level traverse order does not match tensor's level rank");
1748 
1749  if (dimRank + 1 + getInitArgs().size() != args.size())
1750  return emitError("Unmatched number of arguments in the block");
1751 
1752  if (getNumResults() != getInitArgs().size())
1753  return emitError("Mismatch in number of init arguments and results");
1754 
1755  if (getResultTypes() != getInitArgs().getTypes())
1756  return emitError("Mismatch in types of init arguments and results");
1757 
1758  // Cannot mark this const, because the getters aren't.
1759  auto yield = cast<YieldOp>(getBody()->getTerminator());
1760  if (yield.getNumOperands() != getNumResults() ||
1761  yield.getOperands().getTypes() != getResultTypes())
1762  return emitError("Mismatch in types of yield values and results");
1763 
1764  const auto iTp = IndexType::get(getContext());
1765  for (Dimension d = 0; d < dimRank; d++)
1766  if (args[d].getType() != iTp)
1767  return emitError(
1768  llvm::formatv("Expecting Index type for argument at index {0}", d));
1769 
1770  const auto elemTp = t.getElementType();
1771  const auto valueTp = args[dimRank].getType();
1772  if (elemTp != valueTp)
1773  return emitError(llvm::formatv("Unmatched element type between input tensor "
1774  "and block argument, expected: {0}, got: {1}",
1775  elemTp, valueTp));
1776  return success();
1777 }
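// An illustrative sparse_tensor.foreach matching the block layout verified
// above: dimRank coordinates, then the element value, then the reduction
// arguments (types and #CSR are assumptions):
//
//   %r = sparse_tensor.foreach in %t init(%zero)
//       : tensor<8x8xf64, #CSR>, f64 -> f64 do {
//     ^bb0(%i: index, %j: index, %v: f64, %acc: f64):
//       %sum = arith.addf %acc, %v : f64
//       sparse_tensor.yield %sum : f64
//   }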
1778 
1779 OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
1780  if (getSparseTensorEncoding(getInputCoo().getType()) ==
1781  getSparseTensorEncoding(getResultCoo().getType()))
1782  return getInputCoo();
1783 
1784  return {};
1785 }
1786 
1787 LogicalResult ReorderCOOOp::verify() {
1788  SparseTensorType srcStt = getSparseTensorType(getInputCoo());
1789  SparseTensorType dstStt = getSparseTensorType(getResultCoo());
1790 
1791  if (!srcStt.isCOOType() || !dstStt.isCOOType())
1792  return emitError("Expected COO sparse tensors only");
1793 
1794  if (!srcStt.hasSameDimToLvl(dstStt))
1795  return emitError("Unmatched dim2lvl map between input and result COO");
1796 
1797  if (srcStt.getPosType() != dstStt.getPosType() ||
1798  srcStt.getCrdType() != dstStt.getCrdType() ||
1799  srcStt.getElementType() != dstStt.getElementType())
1800  return emitError("Unmatched storage format between input and result COO");
1801 
1802  return success();
1803 }
1804 
1805 LogicalResult ReduceOp::verify() {
1806  Type inputType = getX().getType();
1807  Region &formula = getRegion();
1808  return verifyNumBlockArgs(this, formula, "reduce",
1809  TypeRange{inputType, inputType}, inputType);
1810 }
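// An illustrative sparse_tensor.reduce with a custom product reduction; the
// third operand supplies the identity value (f64 is an assumption):
//
//   %r = sparse_tensor.reduce %x, %y, %identity : f64 {
//     ^bb0(%a: f64, %b: f64):
//       %p = arith.mulf %a, %b : f64
//       sparse_tensor.yield %p : f64
//   }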
1811 
1812 LogicalResult SelectOp::verify() {
1813  Builder b(getContext());
1814  Type inputType = getX().getType();
1815  Type boolType = b.getI1Type();
1816  Region &formula = getRegion();
1817  return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
1818  boolType);
1819 }
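// An illustrative sparse_tensor.select that keeps only values greater than an
// outer %zero; the region takes the input type and yields i1 (f64 is an
// assumption):
//
//   %s = sparse_tensor.select %x : f64 {
//     ^bb0(%a: f64):
//       %keep = arith.cmpf "ogt", %a, %zero : f64
//       sparse_tensor.yield %keep : i1
//   }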
1820 
1821 LogicalResult SortOp::verify() {
1822  AffineMap xPerm = getPermMap();
1823  uint64_t nx = xPerm.getNumDims();
1824  if (nx < 1)
1825  emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
1826 
1827  if (!xPerm.isPermutation())
1828  emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
1829 
1830  // We can't check the size of the buffers when n or buffer dimensions aren't
1831  // compile-time constants.
1832  std::optional<int64_t> cn = getConstantIntValue(getN());
1833  if (!cn)
1834  return success();
1835 
1836  // Verify dimensions.
1837  const auto checkDim = [&](Value v, Size minSize, const char *message) {
1838  const Size sh = getMemRefType(v).getShape()[0];
1839  if (!ShapedType::isDynamic(sh) && sh < minSize)
1840  emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
1841  };
1842  uint64_t n = cn.value();
1843  uint64_t ny = 0;
1844  if (auto nyAttr = getNyAttr())
1845  ny = nyAttr.getInt();
1846  checkDim(getXy(), n * (nx + ny),
1847  "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
1848  for (Value opnd : getYs())
1849  checkDim(opnd, n, "Expected dimension(y) >= n");
1850 
1851  return success();
1852 }
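// A worked instance of the size checks above (the concrete values are
// assumptions): with n = 3, a rank-2 perm_map (nx = 2), and ny = 1, the xy
// buffer must hold at least n * (nx + ny) = 3 * 3 = 9 elements, and each y
// buffer at least n = 3 elements.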
1853 
1854 LogicalResult YieldOp::verify() {
1855  // Check for compatible parent.
1856  auto *parentOp = (*this)->getParentOp();
1857  if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
1858  isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
1859  isa<ForeachOp>(parentOp))
1860  return success();
1861 
1862  return emitOpError("expected parent op to be sparse_tensor unary, binary, "
1863  "reduce, select or foreach");
1864 }
1865 
1866 /// Materialize a single constant operation from a given attribute value with
1867 /// the desired resultant type.
1868 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
1869  Attribute value, Type type,
1870  Location loc) {
1871  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
1872  return op;
1873  return nullptr;
1874 }
1875 
1876 namespace {
1877 struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
1878  using OpAsmDialectInterface::OpAsmDialectInterface;
1879 
1880  AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
1881  if (isa<SparseTensorEncodingAttr>(attr)) {
1882  os << "sparse";
1883  return AliasResult::OverridableAlias;
1884  }
1885  return AliasResult::NoAlias;
1886  }
1887 };
1888 } // namespace
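// With this interface registered, the printer abbreviates sparse encodings to
// a "sparse" alias; an illustrative round-trip (the map is an assumption):
//
//   #sparse = #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed)
//   }>
//   ... tensor<8x8xf64, #sparse> ...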
1889 
1890 void SparseTensorDialect::initialize() {
1891  addInterface<SparseTensorAsmDialectInterface>();
1892  addAttributes<
1893 #define GET_ATTRDEF_LIST
1894 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
1895  >();
1896  addTypes<
1897 #define GET_TYPEDEF_LIST
1898 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
1899  >();
1900  addOperations<
1901 #define GET_OP_LIST
1902 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
1903  >();
1904 }
1905 
1906 #define GET_OP_CLASSES
1907 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
1908 
1909 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"