//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "Detail/DimLvlMapParser.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"

// Forward declarations: the following custom print/parsing methods are
// referenced by the generated code for SparseTensorTypes.td.

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed
// as well.
namespace mlir::sparse_tensor {
llvm::hash_code hash_value(LevelType lt) {
  return llvm::hash_value(static_cast<uint64_t>(lt));
}
} // namespace mlir::sparse_tensor

//===----------------------------------------------------------------------===//
// Local Convenience Methods.
//===----------------------------------------------------------------------===//

static constexpr bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}
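
// Note that a bitwidth of 0 selects the native `index` type for positions
// and coordinates (see getPosElemType() and getCrdElemType() below).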

static SmallVector<Size>
getSparseFieldShape(const SparseTensorEncodingAttr enc,
                    std::optional<ArrayRef<int64_t>> dimShape) {
  assert(enc);
  // With only the encoding, we cannot determine the static shape of the
  // leading batch levels; we therefore return a dynamically shaped memref
  // instead.
  SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
  if (dimShape.has_value()) {
    // If the actual tensor shape is provided, we can refine the leading
    // batch dimensions.
    SmallVector<int64_t> lvlShape =
        enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
    memrefShape.assign(lvlShape.begin(),
                       lvlShape.begin() + enc.getBatchLvlRank());
  }
  // One more dynamic dimension to store the sparse level.
  memrefShape.push_back(ShapedType::kDynamic);
  return memrefShape;
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect StorageLayout.
//===----------------------------------------------------------------------===//

static constexpr Level kInvalidLevel = -1u;
static constexpr Level kInvalidFieldIndex = -1u;
static constexpr FieldIndex kDataFieldStartingIdx = 0;

void StorageLayout::foreachField(
    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) const {
  const auto lvlTypes = enc.getLvlTypes();
  const Level lvlRank = enc.getLvlRank();
  SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments();
  FieldIndex fieldIdx = kDataFieldStartingIdx;

  ArrayRef cooSegsRef = cooSegs;
  // Per-level storage.
  for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
    const auto lt = lvlTypes[l];
    if (isWithPosLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
        return;
    }
    if (isWithCrdLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
        return;
    }
    if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
      if (!cooSegsRef.front().isSoA) {
        // AoS COO: all singletons are fused into one memref. Skip the entire
        // COO segment.
        l = cooSegsRef.front().lvlRange.second;
      } else {
        // SoA COO: each singleton level has its own memref.
        l++;
      }
      // Expire the handled COO segment.
      cooSegsRef = cooSegsRef.drop_front();
    } else {
      // Non-COO levels.
      l++;
    }
  }
  // The values array.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
  // Put metadata at the end.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
}
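
// For example, a 2-D CSR encoding with level types (dense, compressed)
// yields the field sequence
//   positions[1], coordinates[1], values, storage_specifier
// since only the compressed level carries a positions and a coordinates
// memref.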

void sparse_tensor::foreachFieldAndTypeInSparseTensor(
    SparseTensorType stt,
    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) {
  assert(stt.hasEncoding());

  SmallVector<int64_t> memrefShape =
      getSparseFieldShape(stt.getEncoding(), stt.getDimShape());

  const Type specType = StorageSpecifierType::get(stt.getEncoding());
  // memref<[batch] x ? x pos> positions
  const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
  // memref<[batch] x ? x crd> coordinates
  const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
  // memref<[batch] x ? x eltType> values
  const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());

  StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
                                   callback](FieldIndex fieldIdx,
                                             SparseTensorFieldKind fieldKind,
                                             Level lvl, LevelType lt) -> bool {
    switch (fieldKind) {
    case SparseTensorFieldKind::StorageSpec:
      return callback(specType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::PosMemRef:
      return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::CrdMemRef:
      return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::ValMemRef:
      return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
    }
    llvm_unreachable("unrecognized field kind");
  });
}

unsigned StorageLayout::getNumFields() const {
  unsigned numFields = 0;
  foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
    numFields++;
    return true;
  });
  return numFields;
}

unsigned StorageLayout::getNumDataFields() const {
  unsigned numFields = 0; // one value memref
  foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
    if (fidx >= kDataFieldStartingIdx)
      numFields++;
    return true;
  });
  numFields -= 1; // the last field is the StorageSpecifier
  assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
  return numFields;
}
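
// Continuing the CSR example above: getNumFields() returns 4 (positions,
// coordinates, values, storage specifier) and getNumDataFields() returns 3,
// i.e. everything except the trailing storage specifier.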

std::pair<FieldIndex, unsigned>
StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
                                      std::optional<Level> lvl) const {
  FieldIndex fieldIdx = kInvalidFieldIndex;
  unsigned stride = 1;
  if (kind == SparseTensorFieldKind::CrdMemRef) {
    assert(lvl.has_value());
    const Level cooStart = SparseTensorType(enc).getAoSCOOStart();
    const Level lvlRank = enc.getLvlRank();
    if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
      lvl = cooStart;
      stride = lvlRank - cooStart;
    }
  }
  foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
                                      SparseTensorFieldKind fKind, Level fLvl,
                                      LevelType lt) -> bool {
    if ((lvl && fLvl == lvl.value() && kind == fKind) ||
        (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
      fieldIdx = fIdx;
      // Return false to break the iteration.
      return false;
    }
    return true;
  });
  assert(fieldIdx != kInvalidFieldIndex);
  return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
}
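
// For an AoS COO region, all singleton levels share a single coordinates
// memref, so a CrdMemRef query for any level in that region resolves to the
// field of the first COO level, with the stride set to the number of levels
// in the region (see the cooStart adjustment above).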

//===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//

std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
  return isDynamic(v) ? std::nullopt
                      : std::make_optional(static_cast<uint64_t>(v));
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
  return getStatic(getOffset());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
  return getStatic(getStride());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
  return getStatic(getSize());
}

bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
  return isDynamic(getOffset()) && isDynamic(getStride()) &&
         isDynamic(getSize());
}

std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
  return isDynamic(v) ? "?" : std::to_string(v);
}

void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
  assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
  os << '(';
  os << getStaticString(getOffset());
  os << ", ";
  os << getStaticString(getSize());
  os << ", ";
  os << getStaticString(getStride());
  os << ')';
}

void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
  print(printer.getStream());
}

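// A dimension slice prints (and parses) as `(offset, size, stride)`, with `?`
// for dynamic entries, e.g. `(0, 4, 1)` or `(?, ?, ?)`.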
static ParseResult parseOptionalStaticSlice(int64_t &result,
                                            AsmParser &parser) {
  auto parseResult = parser.parseOptionalInteger(result);
  if (parseResult.has_value()) {
    if (parseResult.value().succeeded() && result < 0) {
      parser.emitError(
          parser.getCurrentLocation(),
          "expect positive value or ? for slice offset/size/stride");
      return failure();
    }
    return parseResult.value();
  }

  // Else, expect a '?' representing a dynamic slice value.
  result = SparseTensorDimSliceAttr::kDynamic;
  return parser.parseQuestion();
}

Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
  int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;

  if (failed(parser.parseLParen()) ||
      failed(parseOptionalStaticSlice(offset, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(size, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(stride, parser)) ||
      failed(parser.parseRParen()))
    return {};

  return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
                                                     offset, size, stride);
}

LogicalResult
SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 int64_t offset, int64_t size, int64_t stride) {
  if (!isDynamic(offset) && offset < 0)
    return emitError() << "expect non-negative value or ? for slice offset";
  if (!isDynamic(size) && size <= 0)
    return emitError() << "expect positive value or ? for slice size";
  if (!isDynamic(stride) && stride <= 0)
    return emitError() << "expect positive value or ? for slice stride";
  return success();
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
  return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
  return withDimToLvl(AffineMap());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                        unsigned crdWidth) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
      crdWidth, getExplicitVal(), getImplicitVal());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
  return withBitWidths(0, 0);
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), explicitVal, getImplicitVal());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
  return withExplicitVal(Attribute());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), implicitVal);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
  return withImplicitVal(Attribute());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
  return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
}

uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
  ArrayRef<LevelType> lvlTypes = getLvlTypes();
  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  return std::distance(lastBatch, lvlTypes.rend());
}

bool SparseTensorEncodingAttr::isAllDense() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
}

bool SparseTensorEncodingAttr::isAllOrdered() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
}

Type SparseTensorEncodingAttr::getCrdElemType() const {
  if (!getImpl())
    return nullptr;
  if (getCrdWidth())
    return IntegerType::get(getContext(), getCrdWidth());
  return IndexType::get(getContext());
}

Type SparseTensorEncodingAttr::getPosElemType() const {
  if (!getImpl())
    return nullptr;
  if (getPosWidth())
    return IntegerType::get(getContext(), getPosWidth());
  return IndexType::get(getContext());
}

MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
  return MemRefType::get(shape, getCrdElemType());
}

MemRefType SparseTensorEncodingAttr::getPosMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
  return MemRefType::get(shape, getPosElemType());
}

bool SparseTensorEncodingAttr::isIdentity() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
}

bool SparseTensorEncodingAttr::isPermutation() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
}

Dimension SparseTensorEncodingAttr::getDimRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  const auto dimToLvl = getDimToLvl();
  return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
}

Level SparseTensorEncodingAttr::getLvlRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return getLvlTypes().size();
}

LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
  if (!getImpl())
    return LevelFormat::Batch;
  assert(l < getLvlRank() && "Level is out of bounds");
  return getLvlTypes()[l];
}

bool SparseTensorEncodingAttr::isSlice() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return !getDimSlices().empty();
}

SparseTensorDimSliceAttr
SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
  assert(isSlice() && "Is not a slice");
  const auto dimSlices = getDimSlices();
  assert(dim < dimSlices.size() && "Dimension is out of bounds");
  return dimSlices[dim];
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
  return getDimSlice(dim).getStaticOffset();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
  return getDimSlice(dim).getStaticStride();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
  return getStaticDimSliceOffset(toDim(*this, lvl));
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
  return getStaticDimSliceStride(toDim(*this, lvl));
}

SmallVector<int64_t>
SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
                                         CrdTransDirectionKind dir) const {
  if (isIdentity())
    return SmallVector<int64_t>(srcShape);

  SmallVector<int64_t> ret;
  unsigned rank =
      dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
  ret.reserve(rank);

  if (isPermutation()) {
    for (unsigned r = 0; r < rank; r++) {
      unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
                                                             : toLvl(*this, r);
      ret.push_back(srcShape[trans]);
    }
    return ret;
  }

  // Handle non-permutation maps.
  AffineMap transMap =
      dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();

  SmallVector<AffineExpr> dimRep;
  dimRep.reserve(srcShape.size());
  for (int64_t sz : srcShape) {
    if (!ShapedType::isDynamic(sz)) {
      // Push back the max coordinate for the given dimension/level size.
      dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
    } else {
      // A dynamic size; use an AffineDimExpr to symbolize the value.
      dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
    }
  }

  for (AffineExpr exp : transMap.getResults()) {
    // Do constant propagation on the affine map.
    AffineExpr evalExp =
        simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
    // Use the llvm namespace here to avoid ambiguity.
    if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
      ret.push_back(c.getValue() + 1);
    } else {
      if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
          mod && mod.getKind() == AffineExprKind::Mod) {
        // We can still infer a static bound for expressions of the form
        // "d % constant" since d % constant is in [0, constant).
        if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
          ret.push_back(bound.getValue());
          continue;
        }
      }
      ret.push_back(ShapedType::kDynamic);
    }
  }
  assert(ret.size() == rank);
  return ret;
}
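
// As an illustrative example, for the block mapping
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// translating the dimension shape [6, 9] in the dim2lvl direction yields the
// level shape [3, 3, 2, 3]; the mod results get their static bounds from the
// "d % constant" rule above.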

ValueRange
SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
                                        ValueRange crds,
                                        CrdTransDirectionKind dir) const {
  if (!getImpl())
    return crds;

  SmallVector<Type> retType(
      dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
      builder.getIndexType());
  auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
  return transOp.getOutCrds();
}
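
// For reference, a typical encoding such as CSR is written as
//   #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed),
//     posWidth = 64,
//     crdWidth = 32
//   }>
// The parser below accepts the keys "map", "posWidth", "crdWidth",
// "explicitVal", and "implicitVal".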

Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  // Open "<{" part.
  if (failed(parser.parseLess()))
    return {};
  if (failed(parser.parseLBrace()))
    return {};

  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<LevelType> lvlTypes;
  SmallVector<SparseTensorDimSliceAttr> dimSlices;
  AffineMap dimToLvl = {};
  AffineMap lvlToDim = {};
  unsigned posWidth = 0;
  unsigned crdWidth = 0;
  Attribute explicitVal;
  Attribute implicitVal;
  StringRef attrName;
  SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
                                    "explicitVal", "implicitVal"};
  while (succeeded(parser.parseOptionalKeyword(&attrName))) {
    // Detect admissible keyword.
    auto *it = find(keys, attrName);
    if (it == keys.end()) {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
      return {};
    }
    unsigned keyWordIndex = it - keys.begin();
    // Consume the `=` after the key.
    if (failed(parser.parseEqual()))
      return {};
    // Dispatch on the keyword.
    switch (keyWordIndex) {
    case 0: { // map
      ir_detail::DimLvlMapParser cParser(parser);
      auto res = cParser.parseDimLvlMap();
      if (failed(res))
        return {};
      const auto &dlm = *res;

      const Level lvlRank = dlm.getLvlRank();
      for (Level lvl = 0; lvl < lvlRank; lvl++)
        lvlTypes.push_back(dlm.getLvlType(lvl));

      const Dimension dimRank = dlm.getDimRank();
      for (Dimension dim = 0; dim < dimRank; dim++)
        dimSlices.push_back(dlm.getDimSlice(dim));
      // NOTE: the old syntax requires an all-or-nothing approach to
      // `dimSlices`; therefore, if any slice actually exists then we need
      // to convert null-DSA into default/nop DSA.
      const auto isDefined = [](SparseTensorDimSliceAttr slice) {
        return static_cast<bool>(slice.getImpl());
      };
      if (llvm::any_of(dimSlices, isDefined)) {
        const auto defaultSlice =
            SparseTensorDimSliceAttr::get(parser.getContext());
        for (Dimension dim = 0; dim < dimRank; dim++)
          if (!isDefined(dimSlices[dim]))
            dimSlices[dim] = defaultSlice;
      } else {
        dimSlices.clear();
      }

      dimToLvl = dlm.getDimToLvlMap(parser.getContext());
      lvlToDim = dlm.getLvlToDimMap(parser.getContext());
      break;
    }
    case 1: { // posWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral position bitwidth");
        return {};
      }
      posWidth = intAttr.getInt();
      break;
    }
    case 2: { // crdWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      crdWidth = intAttr.getInt();
      break;
    }
    case 3: { // explicitVal
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
        explicitVal = result;
      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
        explicitVal = result;
      } else {
        parser.emitError(parser.getNameLoc(),
                         "expected a numeric value for explicitVal");
        return {};
      }
      break;
    }
    case 4: { // implicitVal
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
        implicitVal = result;
      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
        implicitVal = result;
      } else {
        parser.emitError(parser.getNameLoc(),
                         "expected a numeric value for implicitVal");
        return {};
      }
      break;
    }
    } // switch
    // Only the last item can omit the comma.
    if (parser.parseOptionalComma().failed())
      break;
  }

  // Close "}>" part.
  if (failed(parser.parseRBrace()))
    return {};
  if (failed(parser.parseGreater()))
    return {};

  // Construct struct-like storage for the attribute.
  if (!lvlToDim || lvlToDim.isEmpty()) {
    lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
  }
  return parser.getChecked<SparseTensorEncodingAttr>(
      parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
      explicitVal, implicitVal, dimSlices);
}

void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
  auto map = static_cast<AffineMap>(getDimToLvl());
  // An empty affine map indicates the identity map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
  printer << "<{ map = ";
  printSymbols(map, printer);
  printer << '(';
  printDimensions(map, printer, getDimSlices());
  printer << ") -> (";
  printLevels(map, printer, getLvlTypes());
  printer << ')';
  // Print remaining members only for non-default values.
  if (getPosWidth())
    printer << ", posWidth = " << getPosWidth();
  if (getCrdWidth())
    printer << ", crdWidth = " << getCrdWidth();
  if (getExplicitVal()) {
    printer << ", explicitVal = " << getExplicitVal();
  }
  if (getImplicitVal())
    printer << ", implicitVal = " << getImplicitVal();
  printer << " }>";
}

void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
                                            AsmPrinter &printer) const {
  if (map.getNumSymbols() == 0)
    return;
  printer << '[';
  for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
    printer << 's' << i << ", ";
  if (map.getNumSymbols() >= 1)
    printer << 's' << map.getNumSymbols() - 1;
  printer << ']';
}

void SparseTensorEncodingAttr::printDimensions(
    AffineMap &map, AsmPrinter &printer,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  if (!dimSlices.empty()) {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << " : " << dimSlices[i] << ", ";
    if (map.getNumDims() >= 1) {
      printer << 'd' << map.getNumDims() - 1 << " : "
              << dimSlices[map.getNumDims() - 1];
    }
  } else {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << ", ";
    if (map.getNumDims() >= 1)
      printer << 'd' << map.getNumDims() - 1;
  }
}

void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
                                           ArrayRef<LevelType> lvlTypes) const {
  for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
    map.getResult(i).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
  }
  if (map.getNumResults() >= 1) {
    auto lastIndex = map.getNumResults() - 1;
    map.getResult(lastIndex).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[lastIndex]);
  }
}

LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
    AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
    unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
  if (!acceptBitWidth(posWidth))
    return emitError() << "unexpected position bitwidth: " << posWidth;
  if (!acceptBitWidth(crdWidth))
    return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
      it != std::end(lvlTypes)) {
    if (it == lvlTypes.begin() ||
        (!isCompressedLT(*(it - 1)) && !isLooseCompressedLT(*(it - 1))))
      return emitError() << "expected compressed or loose_compressed level "
                            "before singleton level";
    if (!std::all_of(it, lvlTypes.end(),
                     [](LevelType i) { return isSingletonLT(i); }))
      return emitError() << "expected all singleton lvlTypes "
                            "following a singleton level";
    // We can potentially support mixed SoA/AoS singleton levels.
    if (!std::all_of(it, lvlTypes.end(), [it](LevelType i) {
          return it->isa<LevelPropNonDefault::SoA>() ==
                 i.isa<LevelPropNonDefault::SoA>();
        })) {
      return emitError() << "expected all singleton lvlTypes stored in the "
                            "same memory layout (SoA vs AoS).";
    }
  }

  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
    return emitError() << "Batch lvlType can only be leading levels.";

  // The SoA property can only be applied to singleton levels.
  auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
    return lt.isa<LevelPropNonDefault::SoA>();
  });
  if (llvm::any_of(soaLvls, [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      })) {
    return emitError() << "SoA is only applicable to singleton lvlTypes.";
  }

  // TODO: audit formats that actually are supported by backend.
  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
      it != std::end(lvlTypes)) {
    if (it != lvlTypes.end() - 1)
      return emitError() << "expected n_out_of_m to be the last level type";
    if (!std::all_of(lvlTypes.begin(), it,
                     [](LevelType i) { return isDenseLT(i); }))
      return emitError() << "expected all dense lvlTypes "
                            "before a n_out_of_m level";
    if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
      if (!isBlockSparsity(dimToLvl)) {
        return emitError()
               << "expected 1xm block structure for n_out_of_m level";
      }
      auto sizes = getBlockSize(dimToLvl);
      unsigned coefficient = 0;
      for (const auto &elem : sizes) {
        if (elem != 0) {
          if (elem != coefficient && coefficient != 0) {
            return emitError() << "expected only one blocked level "
                                  "with the same coefficients";
          }
          coefficient = elem;
        }
      }
      if (coefficient != getM(*it)) {
        return emitError() << "expected coefficients of Affine expressions "
                              "to be equal to m of n_out_of_m level";
      }
    }
  }
  // Before we can check that the level-rank is consistent/coherent
  // across all fields, we need to define it. The source-of-truth for
  // the `getLvlRank` method is the length of the level-types array,
  // since it must always be provided and have full rank; therefore we
  // use that same source-of-truth here.
  const Level lvlRank = lvlTypes.size();
  if (lvlRank == 0)
    return emitError() << "expected a non-empty array for lvlTypes";
  // We save `dimRank` here because we'll also need it to verify `dimSlices`.
  const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
  if (dimToLvl) {
    if (dimToLvl.getNumResults() != lvlRank)
      return emitError()
             << "level-rank mismatch between dimToLvl and lvlTypes: "
             << dimToLvl.getNumResults() << " != " << lvlRank;
    auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
    // Symbols can't be inferred but are acceptable.
    if (!inferRes && dimToLvl.getNumSymbols() == 0)
      return emitError() << "failed to infer lvlToDim from dimToLvl";
    if (lvlToDim && (inferRes != lvlToDim))
      return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
    if (dimRank > lvlRank)
      return emitError() << "unexpected dimToLvl mapping from " << dimRank
                         << " to " << lvlRank;
  }
  if (!dimSlices.empty()) {
    if (dimSlices.size() != dimRank)
      return emitError()
             << "dimension-rank mismatch between dimSlices and dimToLvl: "
             << dimSlices.size() << " != " << dimRank;
    // Compiler support for `dimSlices` currently requires that the two
    // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
    if (dimRank != lvlRank)
      return emitError()
             << "dimSlices expected dimension-rank to match level-rank: "
             << dimRank << " != " << lvlRank;
  }
  return success();
}

LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<Size> dimShape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity. In particular, this ensures that the
  // level-rank is coherent across all the fields.
  if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
                    getPosWidth(), getCrdWidth(), getExplicitVal(),
                    getImplicitVal(), getDimSlices())))
    return failure();
  // Check integrity with tensor type specifics. In particular, we
  // need only check that the dimension-rank of the tensor agrees with
  // the dimension-rank of the encoding.
  const Dimension dimRank = dimShape.size();
  if (dimRank == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimRank() != dimRank)
    return emitError()
           << "dimension-rank mismatch between encoding and tensor shape: "
           << getDimRank() << " != " << dimRank;
  return success();
}

//===----------------------------------------------------------------------===//
// SparseTensorType Methods.
//===----------------------------------------------------------------------===//

bool SparseTensorType::isCOOType(Level startLvl, bool isUnique) const {
  if (!hasEncoding())
    return false;
  if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
    return false;
  for (Level l = startLvl + 1; l < lvlRank; ++l)
    if (!isSingletonLvl(l))
      return false;
  // If isUnique is true, then make sure that the last level is unique,
  // that is, when lvlRank == 1, the only compressed level is unique,
  // and when lvlRank > 1, the last singleton is unique.
  return !isUnique || isUniqueLvl(lvlRank - 1);
}

Level SparseTensorType::getAoSCOOStart() const {
  SmallVector<COOSegment> coo = getCOOSegments();
  assert(coo.size() == 1 || coo.empty());
  if (!coo.empty() && coo.front().isAoS()) {
    return coo.front().lvlRange.first;
  }
  return lvlRank;
}

SmallVector<COOSegment> SparseTensorType::getCOOSegments() const {
  SmallVector<COOSegment> ret;
  if (!hasEncoding() || lvlRank <= 1)
    return ret;

  ArrayRef<LevelType> lts = getLvlTypes();
  Level l = 0;
  while (l < lvlRank) {
    auto lt = lts[l];
    if (lt.isa<LevelFormat::Compressed>()) {
      auto cur = lts.begin() + l;
      auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      });
      unsigned cooLen = std::distance(cur, end);
      if (cooLen > 1) {
        // To support mixed SoA/AoS COO, we should break the segment when the
        // storage scheme changes; for now we faithfully assume that all
        // consecutive singleton levels have the same storage format, as
        // verified by the STEA verifier.
        ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
                                 lts[l + 1].isa<LevelPropNonDefault::SoA>()});
      }
      l += cooLen;
    } else {
      l++;
    }
  }
  return ret;
}

RankedTensorType
SparseTensorType::getCOOType(bool ordered) const {
  SmallVector<LevelType> lvlTypes;
  lvlTypes.reserve(lvlRank);
  // A non-unique compressed level at beginning (unless this is
  // also the last level, then it is unique).
  lvlTypes.push_back(
      *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
  if (lvlRank > 1) {
    // Followed by n-2 non-unique singleton levels.
    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
                *buildLevelType(LevelFormat::Singleton, ordered, false));
    // Ends with a unique singleton level.
    lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
  }
  auto enc = SparseTensorEncodingAttr::get(
      getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal());
  return RankedTensorType::get(getDimShape(), getElementType(), enc);
}
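
// For example, with lvlRank == 3 and ordered == true this produces the level
// types (compressed(nonunique), singleton(nonunique), singleton), i.e. the
// ordered COO format.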

//===----------------------------------------------------------------------===//
// Convenience Methods.
//===----------------------------------------------------------------------===//

SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
  if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
    return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
  if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
    return mdtp.getEncoding();
  return nullptr;
}

AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
                                             MLIRContext *context) {
  auto map = static_cast<AffineMap>(dimToLvl);
  AffineMap lvlToDim;
  // Return an empty lvlToDim when inference is not successful.
  if (!map || map.getNumSymbols() != 0) {
    lvlToDim = AffineMap();
  } else if (map.isPermutation()) {
    lvlToDim = inversePermutation(map);
  } else if (isBlockSparsity(map)) {
    lvlToDim = inverseBlockSparsity(map, context);
  }
  return lvlToDim;
}

AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
                                                    MLIRContext *context) {
  SmallVector<AffineExpr> lvlExprs;
  auto numLvls = dimToLvl.getNumResults();
  lvlExprs.reserve(numLvls);
  // lvlExprComponents stores information of the floordiv and mod operations
  // applied to the same dimension, so as to build the lvlToDim map.
  std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
  for (unsigned i = 0, n = numLvls; i < n; i++) {
    auto result = dimToLvl.getResult(i);
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::FloorDiv) {
        // Position of the dimension in dimToLvl.
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
               "expected only one floordiv for each dimension");
        SmallVector<AffineExpr, 3> components;
        // Level variable for floordiv.
        components.push_back(getAffineDimExpr(i, context));
        // Multiplier.
        components.push_back(binOp.getRHS());
        // Map key is the position of the dimension.
        lvlExprComponents[pos] = components;
      } else if (result.getKind() == AffineExprKind::Mod) {
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
               "expected floordiv before mod");
        // Add the level variable for mod to the same vector
        // as the corresponding floordiv.
        lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
      } else {
        assert(false && "expected floordiv or mod");
      }
    } else {
      lvlExprs.push_back(getAffineDimExpr(i, context));
    }
  }
  // Build lvlExprs from lvlExprComponents.
  // For example, for il = i floordiv 2 and ii = i mod 2, the components
  // would be [il, 2, ii]. It could be used to build the AffineExpr
  // i = il * 2 + ii in lvlToDim.
  for (auto &components : lvlExprComponents) {
    assert(components.second.size() == 3 &&
           "expected 3 components to build lvlExprs");
    auto mulOp = getAffineBinaryOpExpr(
        AffineExprKind::Mul, components.second[0], components.second[1]);
    auto addOp =
        getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
    lvlExprs.push_back(addOp);
  }
  return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
}
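
// For example, for the block mapping
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// the inverse built here is
//   (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3).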

SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
  assert(isBlockSparsity(dimToLvl) &&
         "expected dimToLvl to be block sparsity for calling getBlockSize");
  SmallVector<unsigned> blockSize;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::Mod) {
        blockSize.push_back(
            dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
      }
    } else {
      blockSize.push_back(0);
    }
  }
  return blockSize;
}
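
// E.g., for (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// this returns {2, 3}: one entry per mod term, and a 0 for any dimension that
// is not blocked.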

bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
  if (!dimToLvl)
    return false;
  std::map<unsigned, int64_t> coefficientMap;
  bool hasBlock = false;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      // Check for "dim op const".
      auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
      auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
      if (!dimOp || !conOp || conOp.getValue() <= 0)
        return false;
      // Inspect "dim / const" or "dim % const".
      auto pos = dimOp.getPosition();
      if (binOp.getKind() == AffineExprKind::FloorDiv) {
        // Expect only one floordiv for each dimension.
        if (coefficientMap.find(pos) != coefficientMap.end())
          return false;
        // Record the coefficient of the floordiv.
        coefficientMap[pos] = conOp.getValue();
      } else if (binOp.getKind() == AffineExprKind::Mod) {
        // Expect a floordiv before the mod.
        if (coefficientMap.find(pos) == coefficientMap.end())
          return false;
        // Expect the mod to have the same coefficient as the floordiv.
        if (conOp.getValue() != coefficientMap[pos])
          return false;
        hasBlock = true;
      } else {
        return false;
      }
    } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
      auto pos = dimOp.getPosition();
      // Expect the dim to be unset.
      if (coefficientMap.find(pos) != coefficientMap.end())
        return false;
      coefficientMap[pos] = 0;
    } else {
      return false;
    }
  }
  return hasBlock;
}

bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
  auto hasNonIdentityMap = [](Value v) {
    auto stt = tryGetSparseTensorType(v);
    return stt && !stt->isIdentity();
  };

  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
         llvm::any_of(op->getResults(), hasNonIdentityMap);
}

Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
  if (enc) {
    assert(enc.isPermutation() && "Non permutation map not supported");
    if (const auto dimToLvl = enc.getDimToLvl())
      return dimToLvl.getDimPosition(l);
  }
  return l;
}

Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
  if (enc) {
    assert(enc.isPermutation() && "Non permutation map not supported");
    if (const auto lvlToDim = enc.getLvlToDim())
      return lvlToDim.getDimPosition(d);
  }
  return d;
}

/// We normalize the sparse tensor encoding attribute by always using
/// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
/// as other variants) lead to the same storage specifier type, and by
/// stripping irrelevant fields that do not alter the sparse tensor memory
/// layout.
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
  SmallVector<LevelType> lts;
  for (auto lt : enc.getLvlTypes())
    lts.push_back(lt.stripStorageIrrelevantProperties());

  return SparseTensorEncodingAttr::get(
      enc.getContext(), lts,
      AffineMap(), // dimToLvl (irrelevant to storage specifier)
      AffineMap(), // lvlToDim (irrelevant to storage specifier)
      // Always use `index` for memSize and lvlSize instead of reusing
      // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
      // value for different bitwidths; it also avoids casting between index
      // and integer (returned by DimOp).
      0, 0,
      Attribute(), // explicitVal (irrelevant to storage specifier)
      Attribute(), // implicitVal (irrelevant to storage specifier)
      enc.getDimSlices());
}

StorageSpecifierType
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
  return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
//===----------------------------------------------------------------------===//

static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
  return success(lvl < getSparseTensorType(tensor).getLvlRank());
}

static LogicalResult isMatchingWidth(Value mem, unsigned width) {
  const Type etp = getMemRefType(mem).getElementType();
  return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
}

static LogicalResult verifySparsifierGetterSetter(
    StorageSpecifierKind mdKind, std::optional<Level> lvl,
    TypedValue<StorageSpecifierType> md, Operation *op) {
  if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
    return op->emitError(
        "redundant level argument for querying value memory size");
  }

  const auto enc = md.getType().getEncoding();
  const Level lvlRank = enc.getLvlRank();

  if (mdKind == StorageSpecifierKind::DimOffset ||
      mdKind == StorageSpecifierKind::DimStride)
    if (!enc.isSlice())
      return op->emitError("requested slice data on non-slice tensor");

  if (mdKind != StorageSpecifierKind::ValMemSize) {
    if (!lvl)
      return op->emitError("missing level argument");

    const Level l = lvl.value();
    if (l >= lvlRank)
      return op->emitError("requested level is out of bounds");

    if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
      return op->emitError(
          "requested position memory size on a singleton level");
  }
  return success();
}

static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
  switch (kind) {
  case SparseTensorFieldKind::CrdMemRef:
    return stt.getCrdType();
  case SparseTensorFieldKind::PosMemRef:
    return stt.getPosType();
  case SparseTensorFieldKind::ValMemRef:
    return stt.getElementType();
  case SparseTensorFieldKind::StorageSpec:
    return nullptr;
  }
  llvm_unreachable("Unrecognizable FieldKind");
}

static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                      SparseTensorType stt,
                                      RankedTensorType valTp,
                                      TypeRange lvlTps) {
  if (requiresStaticShape && !stt.hasStaticDimShape())
    return op->emitError("the sparse-tensor must have static shape");
  if (!stt.hasEncoding())
    return op->emitError("the sparse-tensor must have an encoding attribute");

  // Verifies the trailing COO.
  Level cooStartLvl = stt.getAoSCOOStart();
  if (cooStartLvl < stt.getLvlRank()) {
    // We only support trailing COO for now; it must be the last input.
    auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
    // The coordinates should be in the shape of <? x rank>.
    unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
    if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
      op->emitError("input/output trailing COO level-ranks don't match");
    }
  }

  // Verifies that all types match.
  StorageLayout layout(stt.getEncoding());
  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
    return op->emitError("inconsistent number of fields between input/output");

  unsigned idx = 0;
  bool misMatch = false;
  layout.foreachField([&idx, &misMatch, stt, valTp,
                       lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
                               Level lvl, LevelType lt) -> bool {
    if (fKind == SparseTensorFieldKind::StorageSpec)
      return true;

    Type inputTp = nullptr;
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      inputTp = valTp;
    } else {
      assert(fid == idx && stt.getLvlType(lvl) == lt);
      inputTp = lvlTps[idx++];
    }
    // The input element type and expected element type should match.
    Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
    Type expElemTp = getFieldElemType(stt, fKind);
    if (inpElemTp != expElemTp) {
      misMatch = true;
      return false; // to terminate the iteration
    }
    return true;
  });

  if (misMatch)
    return op->emitError("input/output element-types don't match");
  return success();
}

LogicalResult AssembleOp::verify() {
  const auto valuesTp = getRankedTensorType(getValues());
  const auto lvlsTp = getLevels().getTypes();
  const auto resTp = getSparseTensorType(getResult());
  return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
}

LogicalResult DisassembleOp::verify() {
  if (getOutValues().getType() != getRetValues().getType())
    return emitError("output values and return value type mismatch");

  for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
    if (ot.getType() != rt.getType())
      return emitError("output levels and return levels type mismatch");

  const auto valuesTp = getRankedTensorType(getRetValues());
  const auto lvlsTp = getRetLevels().getTypes();
  const auto srcTp = getSparseTensorType(getTensor());
  return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
}

LogicalResult ConvertOp::verify() {
  if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
    if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
      if (tp1.getRank() != tp2.getRank())
        return emitError("unexpected conversion mismatch in rank");
      auto dstEnc =
          llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
      if (dstEnc && dstEnc.isSlice())
        return emitError("cannot convert to a sparse tensor slice");

      auto shape1 = tp1.getShape();
      auto shape2 = tp2.getShape();
      // Accept size matches between the source and the destination type
      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches
      // or matches that would need a runtime assert (e.g. 10 vs. 20 or
      // ? vs. 10).
      for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
          return emitError("unexpected conversion mismatch in dimension ") << d;
      return success();
    }
  }
  return emitError("unexpected type in convert");
}

OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
  if (getType() == getSource().getType())
    return getSource();
  return {};
}

bool ConvertOp::needsExtraSort() {
  SparseTensorType srcStt = getSparseTensorType(getSource());
  SparseTensorType dstStt = getSparseTensorType(getDest());

  // We do not need an extra sort when returning unordered sparse tensors or
  // dense tensors, since dense tensors support random access.
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
      srcStt.hasSameDimToLvl(dstStt)) {
    return false;
  }

  // Source and dest tensors are ordered in different ways. We only do direct
  // dense to sparse conversion when the dense input is defined by a sparse
  // constant. Note that we can theoretically always directly convert from
  // dense inputs by rotating dense loops, but it leads to bad cache locality
  // and hurts performance.
  if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
    if (isa<SparseElementsAttr>(constOp.getValue()))
      return false;

  return true;
}

LogicalResult CrdTranslateOp::verify() {
  uint64_t inRank = getEncoder().getLvlRank();
  uint64_t outRank = getEncoder().getDimRank();

  if (getDirection() == CrdTransDirectionKind::dim2lvl)
    std::swap(inRank, outRank);

  if (inRank != getInCrds().size() || outRank != getOutCrds().size())
    return emitError("Coordinate rank mismatch with encoding");

  return success();
}

LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                   SmallVectorImpl<OpFoldResult> &results) {
  if (getEncoder().isIdentity()) {
    results.assign(getInCrds().begin(), getInCrds().end());
    return success();
  }
  if (getEncoder().isPermutation()) {
    AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
                         ? getEncoder().getDimToLvl()
                         : getEncoder().getLvlToDim();
    for (AffineExpr exp : perm.getResults())
      results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
    return success();
  }

  // Fuse dim2lvl/lvl2dim pairs.
  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
  bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
    return v.getDefiningOp() == def;
  });
  if (!sameDef)
    return failure();

  bool oppositeDir = def.getDirection() != getDirection();
  bool sameOracle =
      def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
  bool sameCount = def.getNumResults() == getInCrds().size();
  if (!oppositeDir || !sameOracle || !sameCount)
    return failure();

  // The definition produces the coordinates in the same order as the input
  // coordinates.
  bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
                                [](auto valuePair) {
                                  auto [lhs, rhs] = valuePair;
                                  return lhs == rhs;
                                });

  if (!sameOrder)
    return failure();
  // l1 = dim2lvl (lvl2dim l0)
  // ==> l0
  results.append(def.getInCrds().begin(), def.getInCrds().end());
  return success();
}

void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
                  int64_t index) {
  Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
  return build(builder, state, source, val);
}

LogicalResult LvlOp::verify() {
  if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
    auto stt = getSparseTensorType(getSource());
    if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
      return emitError(
          "Level index exceeds the rank of the input sparse tensor");
  }
  return success();
}

std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability LvlOp::getSpeculatability() {
  auto constantIndex = getConstantLvlIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  assert(constantIndex <
         cast<RankedTensorType>(getSource().getType()).getRank());
  return Speculation::Speculatable;
}

OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
  auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!lvlIndex)
    return {};

  Level lvl = lvlIndex.getAPSInt().getZExtValue();
  auto stt = getSparseTensorType(getSource());
  if (lvl >= stt.getLvlRank()) {
    // Follows the same convention used by the tensor.dim operation. Out of
    // bound indices produce undefined behavior but are still valid IR.
    // Don't choke on them.
    return {};
  }

  // Helper lambda to build an IndexAttr.
  auto getIndexAttr = [this](int64_t lvlSz) {
    return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
  };

  SmallVector<Size> lvlShape = stt.getLvlShape();
  if (!ShapedType::isDynamic(lvlShape[lvl]))
    return getIndexAttr(lvlShape[lvl]);

  return {};
}

void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                             SparseTensorEncodingAttr dstEnc, Value source) {
  auto srcStt = getSparseTensorType(source);
  SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
  SmallVector<int64_t> dstDimShape =
      dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
  auto dstTp =
      RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
  return build(odsBuilder, odsState, dstTp, source);
}

LogicalResult ReinterpretMapOp::verify() {
  auto srcStt = getSparseTensorType(getSource());
  auto dstStt = getSparseTensorType(getDest());
  ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
  ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();

  if (srcLvlTps.size() != dstLvlTps.size())
    return emitError("Level rank mismatch between source/dest tensors");

  for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
    if (srcLvlTp != dstLvlTp)
      return emitError("Level type mismatch between source/dest tensors");

  if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
      srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
    return emitError("Crd/Pos width mismatch between source/dest tensors");
  }

  if (srcStt.getElementType() != dstStt.getElementType())
    return emitError("Element type mismatch between source/dest tensors");

  SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
  SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
  for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
    if (srcLvlSz != dstLvlSz) {
      // Should we allow one side to have a dynamic size, e.g., should <?x?>
      // be compatible with <3x4>? For now, we require all the level sizes to
      // be *exactly* matched for simplicity.
      return emitError("Level size mismatch between source/dest tensors");
    }
  }

  return success();
}

OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
  if (getSource().getType() == getDest().getType())
    return getSource();

  if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
    // A -> B, B -> A ==> A
    if (def.getSource().getType() == getDest().getType())
      return def.getSource();
  }
  return {};
}

template <typename ToBufferOp>
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
                                           OpaqueProperties prop,
                                           RegionRange region,
                                           SmallVectorImpl<mlir::Type> &ret) {
  typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
  Type elemTp = nullptr;
  bool withStride = false;
  if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
    elemTp = stt.getPosType();
  } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
                       std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
    elemTp = stt.getCrdType();
    if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
      withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
  } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
    elemTp = stt.getElementType();
  }

  assert(elemTp && "unhandled operation.");
  SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
  bufShape.push_back(ShapedType::kDynamic);

  auto layout = withStride ? StridedLayoutAttr::get(stt.getContext(),
                                                    ShapedType::kDynamic,
                                                    {ShapedType::kDynamic})
                           : StridedLayoutAttr();
  ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
  return success();
}

LogicalResult ToPositionsOp::verify() {
  auto stt = getSparseTensorType(getTensor());
  if (failed(lvlIsInBounds(getLevel(), getTensor())))
    return emitError("requested level is out of bounds");
  if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
    return emitError("unexpected type for positions");
  return success();
}

LogicalResult
ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                ValueRange ops, DictionaryAttr attr,
                                OpaqueProperties prop, RegionRange region,
                                SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
}

1606  auto stt = getSparseTensorType(getTensor());
1607  if (failed(lvlIsInBounds(getLevel(), getTensor())))
1608  return emitError("requested level is out of bounds");
1609  if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
1610  return emitError("unexpected type for coordinates");
1611  return success();
1612 }
1613 
1614 LogicalResult
1615 ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1616  ValueRange ops, DictionaryAttr attr,
1617  OpaqueProperties prop, RegionRange region,
1618  SmallVectorImpl<mlir::Type> &ret) {
1619  return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
1620 }
1621 
1622 LogicalResult ToCoordinatesBufferOp::verify() {
1623  auto stt = getSparseTensorType(getTensor());
1624  if (stt.getAoSCOOStart() >= stt.getLvlRank())
1625  return emitError("expected sparse tensor with a COO region");
1626  return success();
1627 }
1628 
1629 LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1630  MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
1631  DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
1632  SmallVectorImpl<mlir::Type> &ret) {
1633  return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
1634  ret);
1635 }
1636 
1637 LogicalResult ToValuesOp::verify() {
1638  auto stt = getSparseTensorType(getTensor());
1639  auto mtp = getMemRefType(getResult());
1640  if (stt.getElementType() != mtp.getElementType())
1641  return emitError("unexpected mismatch in element types");
1642  return success();
1643 }
1644 
1645 LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1646  std::optional<Location> loc,
1647  ValueRange ops, DictionaryAttr attr,
1648  OpaqueProperties prop,
1649  RegionRange region,
1650  SmallVectorImpl<mlir::Type> &ret) {
1651  return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
1652 }
1653 
1654 LogicalResult ToSliceOffsetOp::verify() {
1655  auto rank = getRankedTensorType(getSlice()).getRank();
1656  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1657  return emitError("requested dimension out of bound");
1658  return success();
1659 }
1660 
1661 LogicalResult ToSliceStrideOp::verify() {
1662  auto rank = getRankedTensorType(getSlice()).getRank();
1663  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1664  return emitError("requested dimension out of bound");
1665  return success();
1666 }
1667 
1668 LogicalResult GetStorageSpecifierOp::verify() {
1669  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1670  getSpecifier(), getOperation());
1671 }
1672 
1673 template <typename SpecifierOp>
1674 static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
1675  return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1676 }
1677 
1678 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1679  const StorageSpecifierKind kind = getSpecifierKind();
1680  const auto lvl = getLevel();
1681  for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1682  if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1683  return op.getValue();
1684  return {};
1685 }
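// In other words, a get walks backwards through the chain of set ops feeding
// its specifier operand and folds to the value of the first set with a
// matching (kind, level) pair; sets of a different kind or level are skipped,
// since they cannot affect the queried field.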
1686 
1687 LogicalResult SetStorageSpecifierOp::verify() {
1688  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1689  getSpecifier(), getOperation());
1690 }
1691 
1692 template <class T>
1693 static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1694  const char *regionName,
1695  TypeRange inputTypes, Type outputType) {
1696  unsigned numArgs = region.getNumArguments();
1697  unsigned expectedNum = inputTypes.size();
1698  if (numArgs != expectedNum)
1699  return op->emitError() << regionName << " region must have exactly "
1700  << expectedNum << " arguments";
1701 
1702  for (unsigned i = 0; i < numArgs; i++) {
1703  Type typ = region.getArgument(i).getType();
1704  if (typ != inputTypes[i])
1705  return op->emitError() << regionName << " region argument " << (i + 1)
1706  << " type mismatch";
1707  }
1708  Operation *term = region.front().getTerminator();
1709  YieldOp yield = dyn_cast<YieldOp>(term);
1710  if (!yield)
1711  return op->emitError() << regionName
1712  << " region must end with sparse_tensor.yield";
1713  if (!yield.hasSingleResult() ||
1714  yield.getSingleResult().getType() != outputType)
1715  return op->emitError() << regionName << " region yield type mismatch";
1716 
1717  return success();
1718 }
1719 
1720 LogicalResult BinaryOp::verify() {
1721  NamedAttrList attrs = (*this)->getAttrs();
1722  Type leftType = getX().getType();
1723  Type rightType = getY().getType();
1724  Type outputType = getOutput().getType();
1725  Region &overlap = getOverlapRegion();
1726  Region &left = getLeftRegion();
1727  Region &right = getRightRegion();
1728 
1729  // Check correct number of block arguments and return type for each
1730  // non-empty region.
1731  if (!overlap.empty()) {
1732  if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1733  TypeRange{leftType, rightType}, outputType)))
1734  return failure();
1735  }
1736  if (!left.empty()) {
1737  if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1738  outputType)))
1739  return failure();
1740  } else if (getLeftIdentity()) {
1741  if (leftType != outputType)
1742  return emitError("left=identity requires first argument to have the same "
1743  "type as the output");
1744  }
1745  if (!right.empty()) {
1746  if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1747  outputType)))
1748  return failure();
1749  } else if (getRightIdentity()) {
1750  if (rightType != outputType)
1751  return emitError("right=identity requires second argument to have the "
1752  "same type as the output");
1753  }
1754  return success();
1755 }
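// A small worked example of the identity case above (types are illustrative):
// with x : f64, y : i32, output : f64, an empty left region, and
// left = identity, verification succeeds because the f64 operand can be
// passed through unchanged; the same op with output : f32 is rejected.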
1756 
1757 LogicalResult UnaryOp::verify() {
1758  Type inputType = getX().getType();
1759  Type outputType = getOutput().getType();
1760 
1761  // Check correct number of block arguments and return type for each
1762  // non-empty region.
1763  Region &present = getPresentRegion();
1764  if (!present.empty()) {
1765  if (failed(verifyNumBlockArgs(this, present, "present",
1766  TypeRange{inputType}, outputType)))
1767  return failure();
1768  }
1769  Region &absent = getAbsentRegion();
1770  if (!absent.empty()) {
1771  if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
1772  outputType)))
1773  return failure();
1774  // Absent branch can only yield invariant values.
1775  Block *absentBlock = &absent.front();
1776  Block *parent = getOperation()->getBlock();
1777  Value absentVal =
1778  cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
1779  if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
1780  if (arg.getOwner() == parent)
1781  return emitError("absent region cannot yield linalg argument");
1782  } else if (Operation *def = absentVal.getDefiningOp()) {
1783  if (!isa<arith::ConstantOp>(def) &&
1784  (def->getBlock() == absentBlock || def->getBlock() == parent))
1785  return emitError("absent region cannot yield locally computed value");
1786  }
1787  }
1788  return success();
1789 }
1790 
1791 bool ConcatenateOp::needsExtraSort() {
1792  SparseTensorType dstStt = getSparseTensorType(*this);
1793  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1794  return false;
1795 
1796  bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1797  return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1798  });
1799  // TODO: When conDim != 0, as long as conDim corresponds to the first level
1800  // in all input/output buffers, and all input/output buffers have the same
1801  // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenating
1802  // CSC matrices along the column).
1803  bool directLowerable =
1804  allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1805  return !directLowerable;
1806 }
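// For instance (a sketch): concatenating ordered CSR inputs into an ordered
// CSR output along dimension 0, where all operands share the output's
// dimToLvl and the output uses the identity mapping, needs no extra sort;
// concatenating along dimension 1, or into a differently mapped output, does.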
1807 
1808 LogicalResult ConcatenateOp::verify() {
1809  const auto dstTp = getSparseTensorType(*this);
1810  const Dimension concatDim = getDimension();
1811  const Dimension dimRank = dstTp.getDimRank();
1812 
1813  if (getInputs().size() <= 1)
1814  return emitError("Need at least two tensors to concatenate.");
1815 
1816  if (concatDim >= dimRank)
1817  return emitError(llvm::formatv(
1818  "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1819  concatDim, dimRank));
1820 
1821  for (const auto &it : llvm::enumerate(getInputs())) {
1822  const auto i = it.index();
1823  const auto srcTp = getSparseTensorType(it.value());
1824  if (srcTp.hasDynamicDimShape())
1825  return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1826  const Dimension srcDimRank = srcTp.getDimRank();
1827  if (srcDimRank != dimRank)
1828  return emitError(
1829  llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1830  "from the output tensor (rank={2}).",
1831  i, srcDimRank, dimRank));
1832  }
1833 
1834  for (Dimension d = 0; d < dimRank; d++) {
1835  const Size dstSh = dstTp.getDimShape()[d];
1836  if (d == concatDim) {
1837  if (!ShapedType::isDynamic(dstSh)) {
1838  // If we reach here, then all inputs have static shapes. So we
1839  // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1840  // to avoid redundant assertions in the loop.
1841  Size sumSz = 0;
1842  for (const auto src : getInputs())
1843  sumSz += getSparseTensorType(src).getDimShape()[d];
1844  // If all dimensions are statically known, the sum of all the input
1845  // dimensions should be equal to the output dimension.
1846  if (sumSz != dstSh)
1847  return emitError(
1848  "The concatenation dimension of the output tensor should be the "
1849  "sum of all the concatenation dimensions of the input tensors.");
1850  }
1851  } else {
1852  Size prev = dstSh;
1853  for (const auto src : getInputs()) {
1854  const auto sh = getSparseTensorType(src).getDimShape()[d];
1855  if (!ShapedType::isDynamic(prev) && sh != prev)
1856  return emitError("All dimensions (except for the concatenating one) "
1857  "should be equal.");
1858  prev = sh;
1859  }
1860  }
1861  }
1862 
1863  return success();
1864 }
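// Shape-wise this means, for example, that concatenating tensor<3x4xf64, #X>
// and tensor<5x4xf64, #X> along dimension 0 must produce a tensor whose first
// dimension is 8 (or dynamic), while the second dimension must be 4
// everywhere (#X stands for an arbitrary sparse encoding).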
1865 
1866 void PushBackOp::build(OpBuilder &builder, OperationState &result,
1867  Value curSize, Value inBuffer, Value value) {
1868  build(builder, result, curSize, inBuffer, value, Value());
1869 }
1870 
1871 LogicalResult PushBackOp::verify() {
1872  if (Value n = getN()) {
1873  std::optional<int64_t> nValue = getConstantIntValue(n);
1874  if (nValue && nValue.value() < 1)
1875  return emitOpError("n must be at least 1");
1876  }
1877  return success();
1878 }
1879 
1880 LogicalResult CompressOp::verify() {
1881  const auto stt = getSparseTensorType(getTensor());
1882  if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1883  return emitOpError("incorrect number of coordinates");
1884  return success();
1885 }
1886 
1887 void ForeachOp::build(
1888  OpBuilder &builder, OperationState &result, Value tensor,
1889  ValueRange initArgs, AffineMapAttr order,
1890  function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
1891  bodyBuilder) {
1892  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
1893  // Builds foreach body.
1894  if (!bodyBuilder)
1895  return;
1896  const auto stt = getSparseTensorType(tensor);
1897  const Dimension dimRank = stt.getDimRank();
1898 
1899  // Starts with `dimRank`-many coordinates.
1900  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
1901  // Followed by one value.
1902  blockArgTypes.push_back(stt.getElementType());
1903  // Followed by the reduction variables.
1904  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
1905 
1906  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());
1907 
1908  OpBuilder::InsertionGuard guard(builder);
1909  auto &region = *result.regions.front();
1910  Block *bodyBlock =
1911  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1912  bodyBuilder(builder, result.location,
1913  bodyBlock->getArguments().slice(0, dimRank),
1914  bodyBlock->getArguments()[dimRank],
1915  bodyBlock->getArguments().drop_front(dimRank + 1));
1916 }
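// As a sketch of the generated block signature (names are illustrative):
// iterating a 2-d f64 tensor with one f64 init argument produces
//   ^bb0(%i : index, %j : index, %v : f64, %acc : f64)
// i.e. dimRank coordinates, the element value, then the reduction carries,
// which is exactly the layout ForeachOp::verify() checks below.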
1917 
1918 LogicalResult ForeachOp::verify() {
1919  const auto t = getSparseTensorType(getTensor());
1920  const Dimension dimRank = t.getDimRank();
1921  const auto args = getBody()->getArguments();
1922 
1923  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1924  return emitError("Level traverse order does not match tensor's level rank");
1925 
1926  if (dimRank + 1 + getInitArgs().size() != args.size())
1927  return emitError("Unmatched number of arguments in the block");
1928 
1929  if (getNumResults() != getInitArgs().size())
1930  return emitError("Mismatch in number of init arguments and results");
1931 
1932  if (getResultTypes() != getInitArgs().getTypes())
1933  return emitError("Mismatch in types of init arguments and results");
1934 
1935  // Cannot mark this const, because the getters aren't.
1936  auto yield = cast<YieldOp>(getBody()->getTerminator());
1937  if (yield.getNumOperands() != getNumResults() ||
1938  yield.getOperands().getTypes() != getResultTypes())
1939  return emitError("Mismatch in types of yield values and results");
1940 
1941  const auto iTp = IndexType::get(getContext());
1942  for (Dimension d = 0; d < dimRank; d++)
1943  if (args[d].getType() != iTp)
1944   return emitError(
1945    llvm::formatv("Expecting Index type for argument at index {0}", d));
1946 
1947  const auto elemTp = t.getElementType();
1948  const auto valueTp = args[dimRank].getType();
1949  if (elemTp != valueTp)
1950   return emitError(llvm::formatv("Unmatched element type between input "
1951    "tensor and block argument, expected: {0}, got: {1}",
1952    elemTp, valueTp));
1953  return success();
1954 }
1955 
1956 OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
1957  if (getSparseTensorEncoding(getInputCoo().getType()) ==
1958  getSparseTensorEncoding(getResultCoo().getType()))
1959  return getInputCoo();
1960 
1961  return {};
1962 }
1963 
1964 LogicalResult ReorderCOOOp::verify() {
1965  SparseTensorType srcStt = getSparseTensorType(getInputCoo());
1966  SparseTensorType dstStt = getSparseTensorType(getResultCoo());
1967 
1968  if (!srcStt.isCOOType() || !dstStt.isCOOType())
1969   return emitError("Expected COO sparse tensors only");
1970 
1971  if (!srcStt.hasSameDimToLvl(dstStt))
1972   return emitError("Unmatched dim2lvl map between input and result COO");
1973 
1974  if (srcStt.getPosType() != dstStt.getPosType() ||
1975   srcStt.getCrdType() != dstStt.getCrdType() ||
1976   srcStt.getElementType() != dstStt.getElementType())
1977   return emitError("Unmatched storage format between input and result COO");
1978 
1979  return success();
1980 }
1981 
1982 LogicalResult ReduceOp::verify() {
1983  Type inputType = getX().getType();
1984  Region &formula = getRegion();
1985  return verifyNumBlockArgs(this, formula, "reduce",
1986  TypeRange{inputType, inputType}, inputType);
1987 }
1988 
1989 LogicalResult SelectOp::verify() {
1990  Builder b(getContext());
1991  Type inputType = getX().getType();
1992  Type boolType = b.getI1Type();
1993  Region &formula = getRegion();
1994  return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
1995  boolType);
1996 }
1997 
1998 LogicalResult SortOp::verify() {
1999  AffineMap xPerm = getPermMap();
2000  uint64_t nx = xPerm.getNumDims();
2001  if (nx < 1)
2002   return emitError(llvm::formatv("Expected rank(perm_map) >= 1, got {0}", nx));
2003 
2004  if (!xPerm.isPermutation())
2005   return emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
2006 
2007  // We can't check the size of the buffers when n or buffer dimensions aren't
2008  // compile-time constants.
2009  std::optional<int64_t> cn = getConstantIntValue(getN());
2010  if (!cn)
2011  return success();
2012 
2013  // Verify dimensions.
2014  const auto checkDim = [&](Value v, Size minSize, const char *message) {
2015  const Size sh = getMemRefType(v).getShape()[0];
2016  if (!ShapedType::isDynamic(sh) && sh < minSize)
2017  emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
2018  };
2019  uint64_t n = cn.value();
2020  uint64_t ny = 0;
2021  if (auto nyAttr = getNyAttr())
2022  ny = nyAttr.getInt();
2023  checkDim(getXy(), n * (nx + ny),
2024  "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
2025  for (Value opnd : getYs())
2026  checkDim(opnd, n, "Expected dimension(y) >= n");
2027 
2028  return success();
2029 }
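// The size check above is simple arithmetic: with a compile-time n = 100, a
// rank-2 perm_map (nx = 2), and ny = 1, the xy buffer must hold at least
// 100 * (2 + 1) = 300 entries, and every y buffer at least 100.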
2030 
2031 //===----------------------------------------------------------------------===//
2032 // Sparse Tensor Iteration Operations.
2033 //===----------------------------------------------------------------------===//
2034 
2035 IterSpaceType IteratorType::getIterSpaceType() const {
2036  return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
2037  getHiLvl());
2038 }
2039 
2040 IteratorType IterSpaceType::getIteratorType() const {
2041  return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
2042 }
2043 
2044 /// Parses a level range in the form "$lo `to` $hi"
2045 /// or simply "$lo" if $hi - $lo = 1
2046 static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
2047  Level &lvlHi) {
2048  if (parser.parseInteger(lvlLo))
2049  return failure();
2050 
2051  if (succeeded(parser.parseOptionalKeyword("to"))) {
2052  if (parser.parseInteger(lvlHi))
2053  return failure();
2054  } else {
2055  lvlHi = lvlLo + 1;
2056  }
2057 
2058  if (lvlHi <= lvlLo)
2059   return parser.emitError(parser.getNameLoc(),
2060    "expect larger level upper bound than lower bound");
2061 
2062  return success();
2063 }
2064 
2065 /// Parses a level range in the form "$lo `to` $hi"
2066 /// or simply "$lo" if $hi - $lo = 1
2067 static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
2068  IntegerAttr &lvlHiAttr) {
2069  Level lvlLo, lvlHi;
2070  if (parseLevelRange(parser, lvlLo, lvlHi))
2071  return failure();
2072 
2073  lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
2074  lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
2075  return success();
2076 }
2077 
2078 /// Prints a level range in the form "$lo `to` $hi"
2079 /// or simply "$lo" if $hi - $lo = 1
2080 static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2081 
2082  if (lo + 1 == hi)
2083  p << lo;
2084  else
2085  p << lo << " to " << hi;
2086 }
2087 
2088 /// Prints a level range in the form "$lo `to` $hi"
2089 /// or simply "$lo" if $hi - $lo = 1
2090 static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
2091  IntegerAttr lvlHi) {
2092  unsigned lo = lvlLo.getValue().getZExtValue();
2093  unsigned hi = lvlHi.getValue().getZExtValue();
2094  printLevelRange(p, lo, hi);
2095 }
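// For example, a range covering only level 2 round-trips as "2", while levels
// 2 through 4 (the upper bound is exclusive) round-trip as "2 to 5";
// parseLevelRange rejects any range whose upper bound is not strictly larger
// than its lower bound.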
2096 
2097 LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2098  MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
2099  DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
2100  SmallVectorImpl<mlir::Type> &ret) {
2101 
2102  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2103  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
2104  ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
2105  adaptor.getHiLvl()));
2106  return success();
2107 }
2108 
2109 LogicalResult ExtractIterSpaceOp::verify() {
2110  if (getLoLvl() >= getHiLvl())
2111   return emitOpError("expected level lower bound to be smaller than upper bound");
2112 
2113  TypedValue<IteratorType> pIter = getParentIter();
2114  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2115  return emitOpError(
2116  "parent iterator should be specified iff the level lower bound is non-zero");
2117  }
2118 
2119  if (pIter) {
2120  IterSpaceType spaceTp = getResultSpace().getType();
2121  if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2122  return emitOpError(
2123  "mismatch in parent iterator encoding and iteration space encoding.");
2124 
2125  if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2126  return emitOpError("parent iterator should be used to extract an "
2127  "iteration space from a consecutive level.");
2128  }
2129 
2130  return success();
2131 }
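// Putting the two rules together (an illustrative sketch): an iteration space
// over levels [0, 2) is extracted directly from the tensor without a parent
// iterator, whereas a space over levels [2, 4) must be extracted through an
// iterator whose space ends exactly at level 2, so that iteration spaces are
// always chained over consecutive levels.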
2132 
2133 /// Materialize a single constant operation from a given attribute value with
2134 /// the desired resultant type.
2135 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2136  Attribute value, Type type,
2137  Location loc) {
2138  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2139  return op;
2140  return nullptr;
2141 }
2142 
2143 namespace {
2144 struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
2145  using OpAsmDialectInterface::OpAsmDialectInterface;
2146 
2147  AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
2148  if (isa<SparseTensorEncodingAttr>(attr)) {
2149  os << "sparse";
2150  return AliasResult::OverridableAlias;
2151  }
2152  return AliasResult::NoAlias;
2153  }
2154 };
2155 } // namespace
2156 
2157 void SparseTensorDialect::initialize() {
2158  addInterface<SparseTensorAsmDialectInterface>();
2159  addAttributes<
2160 #define GET_ATTRDEF_LIST
2161 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
2162  >();
2163  addTypes<
2164 #define GET_TYPEDEF_LIST
2165 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
2166  >();
2167  addOperations<
2168 #define GET_OP_LIST
2169 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2170  >();
2171  declarePromisedInterfaces<
2172  bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
2173  NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
2174  ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
2175 }
2176 
2177 #define GET_OP_CLASSES
2178 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2179 
2180 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
Definition: Enums.h:299