//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "Detail/DimLvlMapParser.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/Bitset.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"

// Forward declarations; the following custom print/parsing methods are
// referenced by the generated code for SparseTensorTypes.td.
static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
                                         mlir::sparse_tensor::Level &,
                                         mlir::sparse_tensor::Level &);
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
                            mlir::sparse_tensor::Level);

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed
// as well.
namespace mlir::sparse_tensor {
llvm::hash_code hash_value(LevelType lt) {
  return llvm::hash_value(static_cast<uint64_t>(lt));
}
} // namespace mlir::sparse_tensor

//===----------------------------------------------------------------------===//
// Local Convenience Methods.
//===----------------------------------------------------------------------===//

static constexpr bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}

static SmallVector<Size>
getSparseFieldShape(const SparseTensorEncodingAttr enc,
                    std::optional<ArrayRef<int64_t>> dimShape) {
  assert(enc);
  // With only the encoding, we cannot determine the static shape of leading
  // batch levels; we therefore return a dynamically shaped memref instead.
  SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
  if (dimShape.has_value()) {
    // If the actual tensor shape is provided, we can then refine the leading
    // batch dimensions.
    SmallVector<int64_t> lvlShape =
        enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
    memrefShape.assign(lvlShape.begin(),
                       lvlShape.begin() + enc.getBatchLvlRank());
  }
  // One more dynamic dimension to store the sparse level.
  memrefShape.push_back(ShapedType::kDynamic);
  return memrefShape;
}
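
// Example (illustrative, not from the original source): for a hypothetical
// encoding with lvlTypes = [batch, compressed] and a known tensor shape 4x8,
// the leading batch level refines to 4, yielding the field shape {4, ?};
// without a tensor shape the result stays fully dynamic, i.e. {?, ?}.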

//===----------------------------------------------------------------------===//
// SparseTensorDialect StorageLayout.
//===----------------------------------------------------------------------===//

static constexpr Level kInvalidLevel = -1u;
static constexpr FieldIndex kInvalidFieldIndex = -1u;
static constexpr FieldIndex kDataFieldStartingIdx = 0;

void StorageLayout::foreachField(
    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) const {
  const auto lvlTypes = enc.getLvlTypes();
  const Level lvlRank = enc.getLvlRank();
  SmallVector<COOSegment> cooSegs = enc.getCOOSegments();
  FieldIndex fieldIdx = kDataFieldStartingIdx;

  ArrayRef cooSegsRef = cooSegs;
  // Per-level storage.
  for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
    const auto lt = lvlTypes[l];
    if (isWithPosLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
        return;
    }
    if (isWithCrdLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
        return;
    }
    if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
      if (!cooSegsRef.front().isSoA) {
        // AoS COO: all singletons are fused into one memref. Skip the entire
        // COO segment.
        l = cooSegsRef.front().lvlRange.second;
      } else {
        // SoA COO: each singleton level has its own memref.
        l++;
      }
      // Expire handled COO segment.
      cooSegsRef = cooSegsRef.drop_front();
    } else {
      // Non-COO levels.
      l++;
    }
  }
  // The values array.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
  // Put metadata at the end.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
}
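
// Example (illustrative): for a plain CSR encoding [dense, compressed], the
// callback above is invoked with
//   (0, PosMemRef, 1, compressed)   positions of level 1
//   (1, CrdMemRef, 1, compressed)   coordinates of level 1
//   (2, ValMemRef, kInvalidLevel, undef)
//   (3, StorageSpec, kInvalidLevel, undef)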

void sparse_tensor::foreachFieldAndTypeInSparseTensor(
    SparseTensorType stt,
    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) {
  assert(stt.hasEncoding());

  SmallVector<int64_t> memrefShape =
      getSparseFieldShape(stt.getEncoding(), stt.getDimShape());

  const Type specType = StorageSpecifierType::get(stt.getEncoding());
  // memref<[batch] x ? x pos> positions
  const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
  // memref<[batch] x ? x crd> coordinates
  const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
  // memref<[batch] x ? x eltType> values
  const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());

  StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
                                   callback](FieldIndex fieldIdx,
                                             SparseTensorFieldKind fieldKind,
                                             Level lvl, LevelType lt) -> bool {
    switch (fieldKind) {
    case SparseTensorFieldKind::StorageSpec:
      return callback(specType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::PosMemRef:
      return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::CrdMemRef:
      return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::ValMemRef:
      return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
    }
    llvm_unreachable("unrecognized field kind");
  });
}

unsigned StorageLayout::getNumFields() const {
  unsigned numFields = 0;
  foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
    numFields++;
    return true;
  });
  return numFields;
}

unsigned StorageLayout::getNumDataFields() const {
  unsigned numFields = 0; // one value memref
  foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
    if (fidx >= kDataFieldStartingIdx)
      numFields++;
    return true;
  });
  numFields -= 1; // the last field is the StorageSpecifier
  assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
  return numFields;
}

std::pair<FieldIndex, unsigned>
StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
                                      std::optional<Level> lvl) const {
  FieldIndex fieldIdx = kInvalidFieldIndex;
  unsigned stride = 1;
  if (kind == SparseTensorFieldKind::CrdMemRef) {
    assert(lvl.has_value());
    const Level cooStart = enc.getAoSCOOStart();
    const Level lvlRank = enc.getLvlRank();
    if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
      lvl = cooStart;
      stride = lvlRank - cooStart;
    }
  }
  foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
                                      SparseTensorFieldKind fKind, Level fLvl,
                                      LevelType lt) -> bool {
    if ((lvl && fLvl == lvl.value() && kind == fKind) ||
        (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
      fieldIdx = fIdx;
      // Returns false to break the iteration.
      return false;
    }
    return true;
  });
  assert(fieldIdx != kInvalidFieldIndex);
  return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
}
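
// Example (illustrative): for the AoS COO encoding
// [compressed(nonunique), singleton], the coordinates of both levels share
// one fused memref, so requesting the CrdMemRef of level 1 yields the field
// index of the level-0 coordinate buffer together with stride 2
// (= lvlRank - cooStart).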

//===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//

std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
  return isDynamic(v) ? std::nullopt
                      : std::make_optional(static_cast<uint64_t>(v));
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
  return getStatic(getOffset());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
  return getStatic(getStride());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
  return getStatic(getSize());
}

bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
  return isDynamic(getOffset()) && isDynamic(getStride()) &&
         isDynamic(getSize());
}

std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
  return isDynamic(v) ? "?" : std::to_string(v);
}

void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
  assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
  os << '(';
  os << getStaticString(getOffset());
  os << ", ";
  os << getStaticString(getSize());
  os << ", ";
  os << getStaticString(getStride());
  os << ')';
}

void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
  print(printer.getStream());
}
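
// For example, a slice with static offset 0, static size 4, and a dynamic
// stride prints as "(0, 4, ?)".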

static ParseResult parseOptionalStaticSlice(int64_t &result,
                                            AsmParser &parser) {
  auto parseResult = parser.parseOptionalInteger(result);
  if (parseResult.has_value()) {
    if (parseResult.value().succeeded() && result < 0) {
      parser.emitError(
          parser.getCurrentLocation(),
          "expect positive value or ? for slice offset/size/stride");
      return failure();
    }
    return parseResult.value();
  }

  // Else, expect a '?' representing a dynamic slice.
  result = SparseTensorDimSliceAttr::kDynamic;
  return parser.parseQuestion();
}

Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
  int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;

  if (failed(parser.parseLParen()) ||
      failed(parseOptionalStaticSlice(offset, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(size, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(stride, parser)) ||
      failed(parser.parseRParen()))
    return {};

  return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
                                                     offset, size, stride);
}

LogicalResult
SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 int64_t offset, int64_t size, int64_t stride) {
  if (!isDynamic(offset) && offset < 0)
    return emitError() << "expect non-negative value or ? for slice offset";
  if (!isDynamic(size) && size <= 0)
    return emitError() << "expect positive value or ? for slice size";
  if (!isDynamic(stride) && stride <= 0)
    return emitError() << "expect positive value or ? for slice stride";
  return success();
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
  return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
  return withDimToLvl(AffineMap());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                        unsigned crdWidth) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
      crdWidth, getExplicitVal(), getImplicitVal());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
  return withBitWidths(0, 0);
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), explicitVal, getImplicitVal());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
  return withExplicitVal(Attribute());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), implicitVal);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
  return withImplicitVal(Attribute());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
  return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
}

uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
  ArrayRef<LevelType> lvlTypes = getLvlTypes();
  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  return std::distance(lastBatch, lvlTypes.rend());
}
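
// E.g., lvlTypes = [batch, batch, dense, compressed] has batch level rank 2,
// while an encoding without batch levels has batch level rank 0.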

bool SparseTensorEncodingAttr::isAllDense() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
}

bool SparseTensorEncodingAttr::isAllOrdered() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
}

Type SparseTensorEncodingAttr::getCrdElemType() const {
  if (!getImpl())
    return nullptr;
  if (getCrdWidth())
    return IntegerType::get(getContext(), getCrdWidth());
  return IndexType::get(getContext());
}

Type SparseTensorEncodingAttr::getPosElemType() const {
  if (!getImpl())
    return nullptr;
  if (getPosWidth())
    return IntegerType::get(getContext(), getPosWidth());
  return IndexType::get(getContext());
}

MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
  return MemRefType::get(shape, getCrdElemType());
}

MemRefType SparseTensorEncodingAttr::getPosMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
  return MemRefType::get(shape, getPosElemType());
}

bool SparseTensorEncodingAttr::isIdentity() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
}

bool SparseTensorEncodingAttr::isPermutation() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
}

Dimension SparseTensorEncodingAttr::getDimRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  const auto dimToLvl = getDimToLvl();
  return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
}

Level SparseTensorEncodingAttr::getLvlRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return getLvlTypes().size();
}

LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
  if (!getImpl())
    return LevelFormat::Batch;
  assert(l < getLvlRank() && "Level is out of bounds");
  return getLvlTypes()[l];
}

bool SparseTensorEncodingAttr::isSlice() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return !getDimSlices().empty();
}

SparseTensorDimSliceAttr
SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
  assert(isSlice() && "Is not a slice");
  const auto dimSlices = getDimSlices();
  assert(dim < dimSlices.size() && "Dimension is out of bounds");
  return dimSlices[dim];
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
  return getDimSlice(dim).getStaticOffset();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
  return getDimSlice(dim).getStaticStride();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
  return getStaticDimSliceOffset(toDim(*this, lvl));
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
  return getStaticDimSliceStride(toDim(*this, lvl));
}

SmallVector<int64_t>
SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
                                         CrdTransDirectionKind dir) const {
  if (isIdentity())
    return SmallVector<int64_t>(srcShape);

  SmallVector<int64_t> ret;
  unsigned rank =
      dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
  ret.reserve(rank);

  if (isPermutation()) {
    for (unsigned r = 0; r < rank; r++) {
      unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
                                                             : toLvl(*this, r);
      ret.push_back(srcShape[trans]);
    }
    return ret;
  }

  // Handle non-permutation maps.
  AffineMap transMap =
      dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();

  SmallVector<AffineExpr> dimRep;
  dimRep.reserve(srcShape.size());
  for (int64_t sz : srcShape) {
    if (!ShapedType::isDynamic(sz)) {
      // Push back the max coordinate for the given dimension/level size.
      dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
    } else {
      // A dynamic size, use an AffineDimExpr to symbolize the value.
      dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
    }
  }

  for (AffineExpr exp : transMap.getResults()) {
    // Do constant propagation on the affine map.
    AffineExpr evalExp =
        simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
    // Use the llvm namespace here to avoid ambiguity.
    if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
      ret.push_back(c.getValue() + 1);
    } else {
      if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
          mod && mod.getKind() == AffineExprKind::Mod) {
        // We can still infer a static bound for expressions of the form
        // "d % constant", since d % constant lies in [0, constant).
        if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
          ret.push_back(bound.getValue());
          continue;
        }
      }
      ret.push_back(ShapedType::kDynamic);
    }
  }
  assert(ret.size() == rank);
  return ret;
}
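
// Example (illustrative): with the 2x3 block map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// translating the dimension shape [6, 9] in the dim2lvl direction yields the
// level shape [3, 3, 2, 3]. A dynamic input such as [?, 9] still infers the
// static bounds of the "mod" levels, giving [?, 3, 2, 3].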

ValueRange
SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
                                        ValueRange crds,
                                        CrdTransDirectionKind dir) const {
  if (!getImpl())
    return crds;

  SmallVector<Type> retType(
      dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
      builder.getIndexType());
  auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
  return transOp.getOutCrds();
}

Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  // Open "<{" part.
  if (failed(parser.parseLess()))
    return {};
  if (failed(parser.parseLBrace()))
    return {};

  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<LevelType> lvlTypes;
  SmallVector<SparseTensorDimSliceAttr> dimSlices;
  AffineMap dimToLvl = {};
  AffineMap lvlToDim = {};
  unsigned posWidth = 0;
  unsigned crdWidth = 0;
  Attribute explicitVal;
  Attribute implicitVal;
  StringRef attrName;
  SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
                                    "explicitVal", "implicitVal"};
  while (succeeded(parser.parseOptionalKeyword(&attrName))) {
    // Detect admissible keyword.
    auto *it = find(keys, attrName);
    if (it == keys.end()) {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
      return {};
    }
    unsigned keyWordIndex = it - keys.begin();
    // Consume the `=` after the key.
    if (failed(parser.parseEqual()))
      return {};
    // Dispatch on the keyword.
    switch (keyWordIndex) {
    case 0: { // map
      ir_detail::DimLvlMapParser cParser(parser);
      auto res = cParser.parseDimLvlMap();
      if (failed(res))
        return {};
      const auto &dlm = *res;

      const Level lvlRank = dlm.getLvlRank();
      for (Level lvl = 0; lvl < lvlRank; lvl++)
        lvlTypes.push_back(dlm.getLvlType(lvl));

      const Dimension dimRank = dlm.getDimRank();
      for (Dimension dim = 0; dim < dimRank; dim++)
        dimSlices.push_back(dlm.getDimSlice(dim));
      // NOTE: the old syntax requires an all-or-nothing approach to
      // `dimSlices`; therefore, if any slice actually exists then we need
      // to convert null-DSA into default/nop DSA.
      const auto isDefined = [](SparseTensorDimSliceAttr slice) {
        return static_cast<bool>(slice.getImpl());
      };
      if (llvm::any_of(dimSlices, isDefined)) {
        const auto defaultSlice =
            SparseTensorDimSliceAttr::get(parser.getContext());
        for (Dimension dim = 0; dim < dimRank; dim++)
          if (!isDefined(dimSlices[dim]))
            dimSlices[dim] = defaultSlice;
      } else {
        dimSlices.clear();
      }

      dimToLvl = dlm.getDimToLvlMap(parser.getContext());
      lvlToDim = dlm.getLvlToDimMap(parser.getContext());
      break;
    }
    case 1: { // posWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral position bitwidth");
        return {};
      }
      posWidth = intAttr.getInt();
      break;
    }
    case 2: { // crdWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      crdWidth = intAttr.getInt();
      break;
    }
    case 3: { // explicitVal
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
        explicitVal = result;
      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
        explicitVal = result;
      } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
        explicitVal = result;
      } else {
        parser.emitError(parser.getNameLoc(),
                         "expected a numeric value for explicitVal");
        return {};
      }
      break;
    }
    case 4: { // implicitVal
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
        implicitVal = result;
      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
        implicitVal = result;
      } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
        implicitVal = result;
      } else {
        parser.emitError(parser.getNameLoc(),
                         "expected a numeric value for implicitVal");
        return {};
      }
      break;
    }
    } // switch
    // Only the last item can omit the comma.
    if (parser.parseOptionalComma().failed())
      break;
  }

  // Close "}>" part.
  if (failed(parser.parseRBrace()))
    return {};
  if (failed(parser.parseGreater()))
    return {};

  // Construct struct-like storage for attribute.
  if (!lvlToDim || lvlToDim.isEmpty()) {
    lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
  }
  return parser.getChecked<SparseTensorEncodingAttr>(
      parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
      explicitVal, implicitVal, dimSlices);
}

void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
  auto map = static_cast<AffineMap>(getDimToLvl());
  // An empty affine map indicates the identity map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
  printer << "<{ map = ";
  printSymbols(map, printer);
  printer << '(';
  printDimensions(map, printer, getDimSlices());
  printer << ") -> (";
  printLevels(map, printer, getLvlTypes());
  printer << ')';
  // Print remaining members only for non-default values.
  if (getPosWidth())
    printer << ", posWidth = " << getPosWidth();
  if (getCrdWidth())
    printer << ", crdWidth = " << getCrdWidth();
  if (getExplicitVal())
    printer << ", explicitVal = " << getExplicitVal();
  if (getImplicitVal())
    printer << ", implicitVal = " << getImplicitVal();
  printer << " }>";
}
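
// A printed encoding reads back with the same syntax, e.g. (illustrative):
//
//   #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed),
//     posWidth = 32,
//     crdWidth = 32
//   }>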

void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
                                            AsmPrinter &printer) const {
  if (map.getNumSymbols() == 0)
    return;
  printer << '[';
  for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
    printer << 's' << i << ", ";
  if (map.getNumSymbols() >= 1)
    printer << 's' << map.getNumSymbols() - 1;
  printer << ']';
}

void SparseTensorEncodingAttr::printDimensions(
    AffineMap &map, AsmPrinter &printer,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  if (!dimSlices.empty()) {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << " : " << dimSlices[i] << ", ";
    if (map.getNumDims() >= 1) {
      printer << 'd' << map.getNumDims() - 1 << " : "
              << dimSlices[map.getNumDims() - 1];
    }
  } else {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << ", ";
    if (map.getNumDims() >= 1)
      printer << 'd' << map.getNumDims() - 1;
  }
}

void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
                                           ArrayRef<LevelType> lvlTypes) const {
  for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
    map.getResult(i).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
  }
  if (map.getNumResults() >= 1) {
    auto lastIndex = map.getNumResults() - 1;
    map.getResult(lastIndex).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[lastIndex]);
  }
}

LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
    AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
    unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
  if (!acceptBitWidth(posWidth))
    return emitError() << "unexpected position bitwidth: " << posWidth;
  if (!acceptBitWidth(crdWidth))
    return emitError() << "unexpected coordinate bitwidth: " << crdWidth;

  // Verify every COO segment.
  auto *it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
  while (it != lvlTypes.end()) {
    if (it == lvlTypes.begin() ||
        !(it - 1)->isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>())
      return emitError() << "expected compressed or loose_compressed level "
                            "before singleton level";

    auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT);
    if (!std::all_of(it, curCOOEnd,
                     [](LevelType i) { return isSingletonLT(i); }))
      return emitError() << "expected all singleton lvlTypes "
                            "following a singleton level";
    // We can potentially support mixed SoA/AoS singleton levels.
    if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
          return it->isa<LevelPropNonDefault::SoA>() ==
                 i.isa<LevelPropNonDefault::SoA>();
        })) {
      return emitError() << "expected all singleton lvlTypes stored in the "
                            "same memory layout (SoA vs AoS).";
    }
    it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
  }

  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
    return emitError() << "Batch lvlType can only be leading levels.";

  // The SoA property can only be applied to singleton levels.
  auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
    return lt.isa<LevelPropNonDefault::SoA>();
  });
  if (llvm::any_of(soaLvls, [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      })) {
    return emitError() << "SoA is only applicable to singleton lvlTypes.";
  }

  // TODO: audit formats that actually are supported by backend.
  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
      it != std::end(lvlTypes)) {
    if (it != lvlTypes.end() - 1)
      return emitError() << "expected n_out_of_m to be the last level type";
    if (!std::all_of(lvlTypes.begin(), it,
                     [](LevelType i) { return isDenseLT(i); }))
      return emitError() << "expected all dense lvlTypes "
                            "before a n_out_of_m level";
    if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
      if (!isBlockSparsity(dimToLvl)) {
        return emitError()
               << "expected 1xm block structure for n_out_of_m level";
      }
      auto sizes = getBlockSize(dimToLvl);
      unsigned coefficient = 0;
      for (const auto &elem : sizes) {
        if (elem != 0) {
          if (elem != coefficient && coefficient != 0) {
            return emitError() << "expected only one blocked level "
                                  "with the same coefficients";
          }
          coefficient = elem;
        }
      }
      if (coefficient != getM(*it)) {
        return emitError() << "expected coefficients of affine expressions "
                              "to be equal to m of n_out_of_m level";
      }
    }
  }
  // Before we can check that the level-rank is consistent/coherent
  // across all fields, we need to define it. The source-of-truth for
  // the `getLvlRank` method is the length of the level-types array,
  // since it must always be provided and have full rank; therefore we
  // use that same source-of-truth here.
  const Level lvlRank = lvlTypes.size();
  if (lvlRank == 0)
    return emitError() << "expected a non-empty array for lvlTypes";
  // We save `dimRank` here because we'll also need it to verify `dimSlices`.
  const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
  if (dimToLvl) {
    if (dimToLvl.getNumResults() != lvlRank)
      return emitError()
             << "level-rank mismatch between dimToLvl and lvlTypes: "
             << dimToLvl.getNumResults() << " != " << lvlRank;
    auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
    // Symbols can't be inferred but are acceptable.
    if (!inferRes && dimToLvl.getNumSymbols() == 0)
      return emitError() << "failed to infer lvlToDim from dimToLvl";
    if (lvlToDim && (inferRes != lvlToDim))
      return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
    if (dimRank > lvlRank)
      return emitError() << "unexpected dimToLvl mapping from " << dimRank
                         << " to " << lvlRank;
  }
  if (!dimSlices.empty()) {
    if (dimSlices.size() != dimRank)
      return emitError()
             << "dimension-rank mismatch between dimSlices and dimToLvl: "
             << dimSlices.size() << " != " << dimRank;
    // Compiler support for `dimSlices` currently requires that the two
    // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
    if (dimRank != lvlRank)
      return emitError()
             << "dimSlices expected dimension-rank to match level-rank: "
             << dimRank << " != " << lvlRank;
  }
  return success();
}

LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<Size> dimShape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity. In particular, this ensures that the
  // level-rank is coherent across all the fields.
  if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
                    getPosWidth(), getCrdWidth(), getExplicitVal(),
                    getImplicitVal(), getDimSlices())))
    return failure();
  // Check integrity with tensor type specifics. In particular, we
  // need only check that the dimension-rank of the tensor agrees with
  // the dimension-rank of the encoding.
  const Dimension dimRank = dimShape.size();
  if (dimRank == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimRank() != dimRank)
    return emitError()
           << "dimension-rank mismatch between encoding and tensor shape: "
           << getDimRank() << " != " << dimRank;
  if (auto expVal = getExplicitVal()) {
    Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
    if (attrType != elementType) {
      return emitError() << "explicit value type mismatch between encoding and "
                         << "tensor element type: " << attrType
                         << " != " << elementType;
    }
  }
  if (auto impVal = getImplicitVal()) {
    Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
    if (attrType != elementType) {
      return emitError() << "implicit value type mismatch between encoding and "
                         << "tensor element type: " << attrType
                         << " != " << elementType;
    }
    // Currently, we only support zero as the implicit value.
    auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
    auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
    auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
    if ((impFVal && impFVal.getValue().isNonZero()) ||
        (impIntVal && !impIntVal.getValue().isZero()) ||
        (impComplexVal && (impComplexVal.getImag().isNonZero() ||
                           impComplexVal.getReal().isNonZero()))) {
      return emitError() << "implicit value must be zero";
    }
  }
  return success();
}

Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
  SmallVector<COOSegment> coo = getCOOSegments();
  assert(coo.size() == 1 || coo.empty());
  if (!coo.empty() && coo.front().isAoS()) {
    return coo.front().lvlRange.first;
  }
  return getLvlRank();
}

SmallVector<COOSegment>
mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
  SmallVector<COOSegment> ret;
  if (getLvlRank() <= 1)
    return ret;

  ArrayRef<LevelType> lts = getLvlTypes();
  Level l = 0;
  while (l < getLvlRank()) {
    auto lt = lts[l];
    if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
      auto cur = lts.begin() + l;
      auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      });
      unsigned cooLen = std::distance(cur, end);
      if (cooLen > 1) {
        // To support mixed SoA/AoS COO, we should break the segment when the
        // storage scheme changes; for now we faithfully assume that all
        // consecutive singleton levels have the same storage format, as
        // verified for the STEA.
        ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
                                 lts[l + 1].isa<LevelPropNonDefault::SoA>()});
      }
      l += cooLen;
    } else {
      l++;
    }
  }
  return ret;
}
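
// Example (illustrative): for lvlTypes
//   [dense, compressed(nonunique), singleton(soa)]
// this returns one segment covering levels [1, 3) marked as SoA, whereas a
// purely dense or CSR-like encoding returns no segments at all.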

//===----------------------------------------------------------------------===//
// SparseTensorType Methods.
//===----------------------------------------------------------------------===//

bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
                                                      bool isUnique) const {
  if (!hasEncoding())
    return false;
  if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
    return false;
  for (Level l = startLvl + 1; l < lvlRank; ++l)
    if (!isSingletonLvl(l))
      return false;
  // If isUnique is true, then make sure that the last level is unique,
  // that is, when lvlRank == 1, the only compressed level is unique,
  // and when lvlRank > 1, the last singleton is unique.
  return !isUnique || isUniqueLvl(lvlRank - 1);
}

RankedTensorType
mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
  SmallVector<LevelType> lvlTypes;
  lvlTypes.reserve(lvlRank);
  // A non-unique compressed level at the beginning (unless this is also the
  // last level, in which case it is unique).
  lvlTypes.push_back(
      *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
  if (lvlRank > 1) {
    // Followed by n-2 non-unique singleton levels.
    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
                *buildLevelType(LevelFormat::Singleton, ordered, false));
    // Ends with a unique singleton level.
    lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
  }
  auto enc = SparseTensorEncodingAttr::get(
      getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal());
  return RankedTensorType::get(getDimShape(), getElementType(), enc);
}
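
// Example (illustrative): calling getCOOType (with ordered = true) on a 2-D
// CSR tensor produces a tensor type whose encoding uses the level types
//   (d0 : compressed(nonunique), d1 : singleton)
// i.e. ordered COO storage with the same dim-to-lvl map and bitwidths.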

//===----------------------------------------------------------------------===//
// Convenience Methods.
//===----------------------------------------------------------------------===//

SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
  if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
    return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
  if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
    return mdtp.getEncoding();
  return nullptr;
}

AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
                                             MLIRContext *context) {
  auto map = static_cast<AffineMap>(dimToLvl);
  AffineMap lvlToDim;
  // Return an empty lvlToDim when inference is not successful.
  if (!map || map.getNumSymbols() != 0) {
    lvlToDim = AffineMap();
  } else if (map.isPermutation()) {
    lvlToDim = inversePermutation(map);
  } else if (isBlockSparsity(map)) {
    lvlToDim = inverseBlockSparsity(map, context);
  }
  return lvlToDim;
}

AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
                                                    MLIRContext *context) {
  SmallVector<AffineExpr> lvlExprs;
  auto numLvls = dimToLvl.getNumResults();
  lvlExprs.reserve(numLvls);
  // lvlExprComponents stores information of the floordiv and mod operations
  // applied to the same dimension, so as to build the lvlToDim map.
  std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
  for (unsigned i = 0, n = numLvls; i < n; i++) {
    auto result = dimToLvl.getResult(i);
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::FloorDiv) {
        // Position of the dimension in dimToLvl.
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
               "expected only one floordiv for each dimension");
        SmallVector<AffineExpr, 3> components;
        // Level variable for floordiv.
        components.push_back(getAffineDimExpr(i, context));
        // Multiplier.
        components.push_back(binOp.getRHS());
        // Map key is the position of the dimension.
        lvlExprComponents[pos] = components;
      } else if (result.getKind() == AffineExprKind::Mod) {
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
               "expected floordiv before mod");
        // Add the level variable for mod to the same vector
        // as the corresponding floordiv.
        lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
      } else {
        assert(false && "expected floordiv or mod");
      }
    } else {
      lvlExprs.push_back(getAffineDimExpr(i, context));
    }
  }
  // Build lvlExprs from lvlExprComponents.
  // For example, for il = i floordiv 2 and ii = i mod 2, the components
  // would be [il, 2, ii]. They can be used to build the AffineExpr
  // i = il * 2 + ii in lvlToDim.
  for (auto &components : lvlExprComponents) {
    assert(components.second.size() == 3 &&
           "expected 3 components to build lvlExprs");
    auto mulOp = getAffineBinaryOpExpr(
        AffineExprKind::Mul, components.second[0], components.second[1]);
    auto addOp =
        getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
    lvlExprs.push_back(addOp);
  }
  return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
}
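
// Example (illustrative): inverting the 2x2 BSR map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
// yields the lvlToDim map
//   (d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 2 + d3).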

SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
  assert(isBlockSparsity(dimToLvl) &&
         "expected dimToLvl to be block sparsity for calling getBlockSize");
  SmallVector<unsigned> blockSize;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::Mod) {
        blockSize.push_back(
            dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
      }
    } else {
      blockSize.push_back(0);
    }
  }
  return blockSize;
}
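
// E.g., for (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// this returns [2, 3]; a dimension that appears as a plain dim expression
// contributes a 0 instead.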

bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
  if (!dimToLvl)
    return false;
  std::map<unsigned, int64_t> coefficientMap;
  bool hasBlock = false;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      // Check for "dim op const".
      auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
      auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
      if (!dimOp || !conOp || conOp.getValue() <= 0)
        return false;
      // Inspect "dim / const" or "dim % const".
      auto pos = dimOp.getPosition();
      if (binOp.getKind() == AffineExprKind::FloorDiv) {
        // Expect only one floordiv for each dimension.
        if (coefficientMap.find(pos) != coefficientMap.end())
          return false;
        // Record the coefficient of the floordiv.
        coefficientMap[pos] = conOp.getValue();
      } else if (binOp.getKind() == AffineExprKind::Mod) {
        // Expect a floordiv before the mod.
        if (coefficientMap.find(pos) == coefficientMap.end())
          return false;
        // Expect the mod to have the same coefficient as the floordiv.
        if (conOp.getValue() != coefficientMap[pos])
          return false;
        hasBlock = true;
      } else {
        return false;
      }
    } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
      auto pos = dimOp.getPosition();
      // Expect the dimension to be unset.
      if (coefficientMap.find(pos) != coefficientMap.end())
        return false;
      coefficientMap[pos] = 0;
    } else {
      return false;
    }
  }
  return hasBlock;
}

bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
  auto hasNonIdentityMap = [](Value v) {
    auto stt = tryGetSparseTensorType(v);
    return stt && !stt->isIdentity();
  };

  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
         llvm::any_of(op->getResults(), hasNonIdentityMap);
}

Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
  if (enc) {
    assert(enc.isPermutation() && "Non permutation map not supported");
    if (const auto dimToLvl = enc.getDimToLvl())
      return dimToLvl.getDimPosition(l);
  }
  return l;
}

Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
  if (enc) {
    assert(enc.isPermutation() && "Non permutation map not supported");
    if (const auto lvlToDim = enc.getLvlToDim())
      return lvlToDim.getDimPosition(d);
  }
  return d;
}
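
// Example (illustrative): for the CSC-style permutation (d0, d1) -> (d1, d0),
// toDim(enc, 0) == 1 and toLvl(enc, 0) == 1; with a null encoding both are
// the identity.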

/// We normalize the sparse tensor encoding attribute by always using
/// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
/// as other variants) lead to the same storage specifier type, and by stripping
/// irrelevant fields that do not alter the sparse tensor memory layout.
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
  SmallVector<LevelType> lts;
  for (auto lt : enc.getLvlTypes())
    lts.push_back(lt.stripStorageIrrelevantProperties());

  return SparseTensorEncodingAttr::get(
      enc.getContext(), lts,
      AffineMap(), // dimToLvl (irrelevant to storage specifier)
      AffineMap(), // lvlToDim (irrelevant to storage specifier)
      // Always use `index` for memSize and lvlSize instead of reusing
      // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
      // values for different bitwidths; it also avoids casting between index
      // and integer (returned by DimOp).
      0, 0,
      Attribute(), // explicitVal (irrelevant to storage specifier)
      Attribute(), // implicitVal (irrelevant to storage specifier)
      enc.getDimSlices());
}
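
// Example (illustrative): encodings that differ only in storage-irrelevant
// details, e.g. compressed(nonunique, nonordered) vs. compressed levels, or
// different posWidth/crdWidth and dimToLvl maps, all map to the same storage
// specifier type.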

StorageSpecifierType
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
  return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
//===----------------------------------------------------------------------===//

static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
  return success(lvl < getSparseTensorType(tensor).getLvlRank());
}

static LogicalResult isMatchingWidth(Value mem, unsigned width) {
  const Type etp = getMemRefType(mem).getElementType();
  return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
}

static LogicalResult verifySparsifierGetterSetter(
    StorageSpecifierKind mdKind, std::optional<Level> lvl,
    TypedValue<StorageSpecifierType> md, Operation *op) {
  if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
    return op->emitError(
        "redundant level argument for querying value memory size");
  }

  const auto enc = md.getType().getEncoding();
  const Level lvlRank = enc.getLvlRank();

  if (mdKind == StorageSpecifierKind::DimOffset ||
      mdKind == StorageSpecifierKind::DimStride)
    if (!enc.isSlice())
      return op->emitError("requested slice data on non-slice tensor");

  if (mdKind != StorageSpecifierKind::ValMemSize) {
    if (!lvl)
      return op->emitError("missing level argument");

    const Level l = lvl.value();
    if (l >= lvlRank)
      return op->emitError("requested level is out of bounds");

    if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
      return op->emitError(
          "requested position memory size on a singleton level");
  }
  return success();
}

static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
  switch (kind) {
  case SparseTensorFieldKind::CrdMemRef:
    return stt.getCrdType();
  case SparseTensorFieldKind::PosMemRef:
    return stt.getPosType();
  case SparseTensorFieldKind::ValMemRef:
    return stt.getElementType();
  case SparseTensorFieldKind::StorageSpec:
    return nullptr;
  }
  llvm_unreachable("Unrecognizable FieldKind");
}

static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                      SparseTensorType stt,
                                      RankedTensorType valTp,
                                      TypeRange lvlTps) {
  if (requiresStaticShape && !stt.hasStaticDimShape())
    return op->emitError("the sparse-tensor must have static shape");
  if (!stt.hasEncoding())
    return op->emitError("the sparse-tensor must have an encoding attribute");

  // Verify the trailing COO.
  Level cooStartLvl = stt.getAoSCOOStart();
  if (cooStartLvl < stt.getLvlRank()) {
    // We only support trailing COO for now; it must be the last input.
    auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
    // The coordinates should have shape <? x rank>.
    unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
    if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
      return op->emitError("input/output trailing COO level-ranks don't match");
    }
  }

  // Verify that all types match.
  StorageLayout layout(stt.getEncoding());
  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
    return op->emitError("inconsistent number of fields between input/output");

  unsigned idx = 0;
  bool misMatch = false;
  layout.foreachField([&idx, &misMatch, stt, valTp,
                       lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
                               Level lvl, LevelType lt) -> bool {
    if (fKind == SparseTensorFieldKind::StorageSpec)
      return true;

    Type inputTp = nullptr;
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      inputTp = valTp;
    } else {
      assert(fid == idx && stt.getLvlType(lvl) == lt);
      inputTp = lvlTps[idx++];
    }
    // The input element type and expected element type should match.
    Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
    Type expElemTp = getFieldElemType(stt, fKind);
    if (inpElemTp != expElemTp) {
      misMatch = true;
      return false; // to terminate the iteration
    }
    return true;
  });

  if (misMatch)
    return op->emitError("input/output element-types don't match");
  return success();
}

LogicalResult AssembleOp::verify() {
  const auto valuesTp = getRankedTensorType(getValues());
  const auto lvlsTp = getLevels().getTypes();
  const auto resTp = getSparseTensorType(getResult());
  return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
}

LogicalResult DisassembleOp::verify() {
  if (getOutValues().getType() != getRetValues().getType())
    return emitError("output values and return value type mismatch");

  for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
    if (ot.getType() != rt.getType())
      return emitError("output levels and return levels type mismatch");

  const auto valuesTp = getRankedTensorType(getRetValues());
  const auto lvlsTp = getRetLevels().getTypes();
  const auto srcTp = getSparseTensorType(getTensor());
  return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
}

LogicalResult ConvertOp::verify() {
  if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
    if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
      if (tp1.getRank() != tp2.getRank())
        return emitError("unexpected conversion mismatch in rank");
      auto dstEnc =
          llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
      if (dstEnc && dstEnc.isSlice())
        return emitError("cannot convert to a sparse tensor slice");

      auto shape1 = tp1.getShape();
      auto shape2 = tp2.getShape();
      // Accept size matches between the source and the destination type
      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
      for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
          return emitError("unexpected conversion mismatch in dimension ") << d;
      return success();
    }
  }
  return emitError("unexpected type in convert");
}

OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
  if (getType() == getSource().getType())
    return getSource();
  return {};
}

bool ConvertOp::needsExtraSort() {
  SparseTensorType srcStt = getSparseTensorType(getSource());
  SparseTensorType dstStt = getSparseTensorType(getDest());

  // We do not need an extra sort when returning unordered sparse tensors or
  // dense tensors, since dense tensors support random access.
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
      srcStt.hasSameDimToLvl(dstStt)) {
    return false;
  }

  // Source and dest tensors are ordered in different ways. We only do direct
  // dense to sparse conversion when the dense input is defined by a sparse
  // constant. Note that we can theoretically always directly convert from
  // dense inputs by rotating dense loops, but it leads to bad cache locality
  // and hurts performance.
  if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
    if (isa<SparseElementsAttr>(constOp.getValue()))
      return false;

  return true;
}

LogicalResult CrdTranslateOp::verify() {
  uint64_t inRank = getEncoder().getLvlRank();
  uint64_t outRank = getEncoder().getDimRank();

  if (getDirection() == CrdTransDirectionKind::dim2lvl)
    std::swap(inRank, outRank);

  if (inRank != getInCrds().size() || outRank != getOutCrds().size())
    return emitError("Coordinate rank mismatch with encoding");

  return success();
}

LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                   SmallVectorImpl<OpFoldResult> &results) {
  if (getEncoder().isIdentity()) {
    results.assign(getInCrds().begin(), getInCrds().end());
    return success();
  }
  if (getEncoder().isPermutation()) {
    AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
                         ? getEncoder().getDimToLvl()
                         : getEncoder().getLvlToDim();
    for (AffineExpr exp : perm.getResults())
      results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
    return success();
  }

  // Fuse dim2lvl/lvl2dim pairs.
  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
  bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
    return v.getDefiningOp() == def;
  });
  if (!sameDef)
    return failure();

  bool oppositeDir = def.getDirection() != getDirection();
  bool sameOracle =
      def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
  bool sameCount = def.getNumResults() == getInCrds().size();
  if (!oppositeDir || !sameOracle || !sameCount)
    return failure();

  // The definition produces the coordinates in the same order as the input
  // coordinates.
  bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
                                [](auto valuePair) {
                                  auto [lhs, rhs] = valuePair;
                                  return lhs == rhs;
                                });

  if (!sameOrder)
    return failure();
  // l1 = dim2lvl (lvl2dim l0)
  // ==> l0
  results.append(def.getInCrds().begin(), def.getInCrds().end());
  return success();
}
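
// Folding example (illustrative, assuming an encoding #BSR with two
// dimensions and four levels): a lvl2dim translation followed by the inverse
// dim2lvl translation over the same encoding,
//   %d:2 = sparse_tensor.crd_translate lvl_to_dim [%l0, %l1, %l2, %l3] as #BSR
//   %l:4 = sparse_tensor.crd_translate dim_to_lvl [%d#0, %d#1] as #BSR
// folds the second op to the original coordinates [%l0, %l1, %l2, %l3].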

void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
                  int64_t index) {
  Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
  return build(builder, state, source, val);
}

LogicalResult LvlOp::verify() {
  if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
    auto stt = getSparseTensorType(getSource());
    if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
      return emitError(
          "Level index exceeds the rank of the input sparse tensor");
  }
  return success();
}

std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability LvlOp::getSpeculatability() {
  auto constantIndex = getConstantLvlIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  assert(constantIndex <
         cast<RankedTensorType>(getSource().getType()).getRank());
  return Speculation::Speculatable;
}

OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
  auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!lvlIndex)
    return {};

  Level lvl = lvlIndex.getAPSInt().getZExtValue();
  auto stt = getSparseTensorType(getSource());
  if (lvl >= stt.getLvlRank()) {
    // Follows the same convention used by the tensor.dim operation. Out of
    // bound indices produce undefined behavior but are still valid IR.
    // Don't choke on them.
    return {};
  }

  // Helper lambda to build an IndexAttr.
  auto getIndexAttr = [this](int64_t lvlSz) {
    return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
  };

  SmallVector<Size> lvlShape = stt.getLvlShape();
  if (!ShapedType::isDynamic(lvlShape[lvl]))
    return getIndexAttr(lvlShape[lvl]);

  return {};
}
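
// Folding example (illustrative): for %t of type tensor<4x?xf64, #CSR>,
//   sparse_tensor.lvl %t, %c0   folds to the constant index 4,
//   sparse_tensor.lvl %t, %c1   does not fold, as the level size is dynamic.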

void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                             SparseTensorEncodingAttr dstEnc, Value source) {
  auto srcStt = getSparseTensorType(source);
  SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
  SmallVector<int64_t> dstDimShape =
      dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
  auto dstTp =
      RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
  return build(odsBuilder, odsState, dstTp, source);
}

LogicalResult ReinterpretMapOp::verify() {
  auto srcStt = getSparseTensorType(getSource());
  auto dstStt = getSparseTensorType(getDest());
  ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
  ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();

  if (srcLvlTps.size() != dstLvlTps.size())
    return emitError("Level rank mismatch between source/dest tensors");

  for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
    if (srcLvlTp != dstLvlTp)
      return emitError("Level type mismatch between source/dest tensors");

  if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
      srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
    return emitError("Crd/Pos width mismatch between source/dest tensors");
  }

  if (srcStt.getElementType() != dstStt.getElementType())
    return emitError("Element type mismatch between source/dest tensors");

  SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
  SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
  for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
    if (srcLvlSz != dstLvlSz) {
      // Should we allow one side to be dynamic size, e.g., <?x?> should be
      // compatible to <3x4>? For now, we require all the level sizes to be
      // *exactly* matched for simplicity.
      return emitError("Level size mismatch between source/dest tensors");
    }
  }

  return success();
}
1580 
1581 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1582  if (getSource().getType() == getDest().getType())
1583  return getSource();
1584 
1585  if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1586  // A -> B, B -> A ==> A
1587  if (def.getSource().getType() == getDest().getType())
1588  return def.getSource();
1589  }
1590  return {};
1591 }
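// For illustration, the round-trip case folded above (sketched loosely with
// placeholder encodings #A and #B):
//   %b = sparse_tensor.reinterpret_map %a : tensor<..., #A> to tensor<..., #B>
//   %c = sparse_tensor.reinterpret_map %b : tensor<..., #B> to tensor<..., #A>
// folds %c directly to %a, since the second op undoes the first.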
1592 
1593 template <typename ToBufferOp>
1594 static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
1595  OpaqueProperties prop,
1596  RegionRange region,
1597  SmallVectorImpl<mlir::Type> &ret) {
1598  typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1599  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
1600  Type elemTp = nullptr;
1601  bool withStride = false;
1602  if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1603  elemTp = stt.getPosType();
1604  } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1605  std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1606  elemTp = stt.getCrdType();
1607  if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1608  withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
1609  } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1610  elemTp = stt.getElementType();
1611  }
1612 
1613  assert(elemTp && "unhandled operation.");
1614  SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
1615  bufShape.push_back(ShapedType::kDynamic);
1616 
1617  auto layout = withStride ? StridedLayoutAttr::get(
1618  stt.getContext(), ShapedType::kDynamic,
1619  {ShapedType::kDynamic})
1620  : StridedLayoutAttr();
1621  ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1622  return success();
1623 }
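// For illustration (assuming 64-bit position/coordinate widths and no batch
// levels): ToPositionsOp infers memref<?xi64>; for a level inside an AoS COO
// region, ToCoordinatesOp infers a strided memref<?xi64, strided<[?], offset: ?>>
// because the coordinates of all COO levels are interleaved in one buffer.
// Leading batch levels would prepend their sizes to the buffer shape.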
1624 
1625 LogicalResult ToPositionsOp::verify() {
1626  auto stt = getSparseTensorType(getTensor());
1627  if (failed(lvlIsInBounds(getLevel(), getTensor())))
1628  return emitError("requested level is out of bounds");
1629  if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
1630  return emitError("unexpected type for positions");
1631  return success();
1632 }
1633 
1634 LogicalResult
1635 ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1636  ValueRange ops, DictionaryAttr attr,
1637  OpaqueProperties prop, RegionRange region,
1638  SmallVectorImpl<mlir::Type> &ret) {
1639  return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
1640 }
1641 
1642 LogicalResult ToCoordinatesOp::verify() {
1643  auto stt = getSparseTensorType(getTensor());
1644  if (failed(lvlIsInBounds(getLevel(), getTensor())))
1645  return emitError("requested level is out of bounds");
1646  if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
1647  return emitError("unexpected type for coordinates");
1648  return success();
1649 }
1650 
1651 LogicalResult
1652 ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1653  ValueRange ops, DictionaryAttr attr,
1654  OpaqueProperties prop, RegionRange region,
1655  SmallVectorImpl<mlir::Type> &ret) {
1656  return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
1657 }
1658 
1659 LogicalResult ToCoordinatesBufferOp::verify() {
1660  auto stt = getSparseTensorType(getTensor());
1661  if (stt.getAoSCOOStart() >= stt.getLvlRank())
1662  return emitError("expected sparse tensor with a COO region");
1663  return success();
1664 }
1665 
1666 LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1667  MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
1668  DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
1669  SmallVectorImpl<mlir::Type> &ret) {
1670  return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
1671  ret);
1672 }
1673 
1674 LogicalResult ToValuesOp::verify() {
1675  auto stt = getSparseTensorType(getTensor());
1676  auto mtp = getMemRefType(getResult());
1677  if (stt.getElementType() != mtp.getElementType())
1678  return emitError("unexpected mismatch in element types");
1679  return success();
1680 }
1681 
1682 LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1683  std::optional<Location> loc,
1684  ValueRange ops, DictionaryAttr attr,
1685  OpaqueProperties prop,
1686  RegionRange region,
1687  SmallVectorImpl<mlir::Type> &ret) {
1688  return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
1689 }
1690 
1691 LogicalResult ToSliceOffsetOp::verify() {
1692  auto rank = getRankedTensorType(getSlice()).getRank();
1693  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1694  return emitError("requested dimension out of bounds");
1695  return success();
1696 }
1697 
1698 LogicalResult ToSliceStrideOp::verify() {
1699  auto rank = getRankedTensorType(getSlice()).getRank();
1700  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1701  return emitError("requested dimension out of bounds");
1702  return success();
1703 }
1704 
1705 LogicalResult GetStorageSpecifierOp::verify() {
1706  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1707  getSpecifier(), getOperation());
1708 }
1709 
1710 template <typename SpecifierOp>
1711 static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
1712  return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1713 }
1714 
1715 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1716  const StorageSpecifierKind kind = getSpecifierKind();
1717  const auto lvl = getLevel();
1718  for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1719  if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1720  return op.getValue();
1721  return {};
1722 }
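// For illustration, the def-chain walk above folds a get-after-set pattern
// (sketched loosely):
//   %s1 = sparse_tensor.storage_specifier.set %s0 lvl_sz at 0 with %v
//   %sz = sparse_tensor.storage_specifier.get %s1 lvl_sz at 0
// to %v directly, skipping over intervening sets of a different kind/level.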
1723 
1724 LogicalResult SetStorageSpecifierOp::verify() {
1725  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1726  getSpecifier(), getOperation());
1727 }
1728 
1729 template <class T>
1730 static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1731  const char *regionName,
1732  TypeRange inputTypes, Type outputType) {
1733  unsigned numArgs = region.getNumArguments();
1734  unsigned expectedNum = inputTypes.size();
1735  if (numArgs != expectedNum)
1736  return op->emitError() << regionName << " region must have exactly "
1737  << expectedNum << " arguments";
1738 
1739  for (unsigned i = 0; i < numArgs; i++) {
1740  Type typ = region.getArgument(i).getType();
1741  if (typ != inputTypes[i])
1742  return op->emitError() << regionName << " region argument " << (i + 1)
1743  << " type mismatch";
1744  }
1745  Operation *term = region.front().getTerminator();
1746  YieldOp yield = dyn_cast<YieldOp>(term);
1747  if (!yield)
1748  return op->emitError() << regionName
1749  << " region must end with sparse_tensor.yield";
1750  if (!yield.hasSingleResult() ||
1751  yield.getSingleResult().getType() != outputType)
1752  return op->emitError() << regionName << " region yield type mismatch";
1753 
1754  return success();
1755 }
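// For illustration, a region that passes these checks for two f64 inputs and
// an f64 output (a sketch, not from the original source):
//   ^bb0(%x: f64, %y: f64):
//     %m = arith.mulf %x, %y : f64
//     sparse_tensor.yield %m : f64
// i.e. one block argument per entry of inputTypes, terminated by a
// sparse_tensor.yield whose single result has the output type.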
1756 
1757 LogicalResult BinaryOp::verify() {
1758  NamedAttrList attrs = (*this)->getAttrs();
1759  Type leftType = getX().getType();
1760  Type rightType = getY().getType();
1761  Type outputType = getOutput().getType();
1762  Region &overlap = getOverlapRegion();
1763  Region &left = getLeftRegion();
1764  Region &right = getRightRegion();
1765 
1766  // Check correct number of block arguments and return type for each
1767  // non-empty region.
1768  if (!overlap.empty()) {
1769  if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1770  TypeRange{leftType, rightType}, outputType)))
1771  return failure();
1772  }
1773  if (!left.empty()) {
1774  if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1775  outputType)))
1776  return failure();
1777  } else if (getLeftIdentity()) {
1778  if (leftType != outputType)
1779  return emitError("left=identity requires first argument to have the same "
1780  "type as the output");
1781  }
1782  if (!right.empty()) {
1783  if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1784  outputType)))
1785  return failure();
1786  } else if (getRightIdentity()) {
1787  if (rightType != outputType)
1788  return emitError("right=identity requires second argument to have the "
1789  "same type as the output");
1790  }
1791  return success();
1792 }
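// For illustration: with left=identity and an empty left region, the first
// operand is passed through unchanged whenever only the left value is
// present, which is why the verifier above insists that leftType equal the
// output type (and symmetrically for right=identity).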
1793 
1794 LogicalResult UnaryOp::verify() {
1795  Type inputType = getX().getType();
1796  Type outputType = getOutput().getType();
1797 
1798  // Check correct number of block arguments and return type for each
1799  // non-empty region.
1800  Region &present = getPresentRegion();
1801  if (!present.empty()) {
1802  if (failed(verifyNumBlockArgs(this, present, "present",
1803  TypeRange{inputType}, outputType)))
1804  return failure();
1805  }
1806  Region &absent = getAbsentRegion();
1807  if (!absent.empty()) {
1808  if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
1809  outputType)))
1810  return failure();
1811  // Absent branch can only yield invariant values.
1812  Block *absentBlock = &absent.front();
1813  Block *parent = getOperation()->getBlock();
1814  Value absentVal =
1815  cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
1816  if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
1817  if (arg.getOwner() == parent)
1818  return emitError("absent region cannot yield linalg argument");
1819  } else if (Operation *def = absentVal.getDefiningOp()) {
1820  if (!isa<arith::ConstantOp>(def) &&
1821  (def->getBlock() == absentBlock || def->getBlock() == parent))
1822  return emitError("absent region cannot yield locally computed value");
1823  }
1824  }
1825  return success();
1826 }
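// For illustration: the absent region may only yield values invariant to the
// surrounding computation, e.g. a constant defined outside the op; yielding a
// block argument of the enclosing (e.g. linalg) region, or a value computed
// inside the absent block itself, is rejected by the checks above.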
1827 
1828 bool ConcatenateOp::needsExtraSort() {
1829  SparseTensorType dstStt = getSparseTensorType(*this);
1830  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1831  return false;
1832 
1833  bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1834  return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1835  });
1836  // TODO: When conDim != 0, as long as conDim corresponds to the first level
1837  // in all input/output buffers, and all input/output buffers have the same
1838  // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenating
1839  // CSC matrices along the column).
1840  bool directLowerable =
1841  allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1842  return !directLowerable;
1843 }
1844 
1845 LogicalResult ConcatenateOp::verify() {
1846  const auto dstTp = getSparseTensorType(*this);
1847  const Dimension concatDim = getDimension();
1848  const Dimension dimRank = dstTp.getDimRank();
1849 
1850  if (getInputs().size() <= 1)
1851  return emitError("Need at least two tensors to concatenate.");
1852 
1853  if (concatDim >= dimRank)
1854  return emitError(llvm::formatv(
1855  "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1856  concatDim, dimRank));
1857 
1858  for (const auto &it : llvm::enumerate(getInputs())) {
1859  const auto i = it.index();
1860  const auto srcTp = getSparseTensorType(it.value());
1861  if (srcTp.hasDynamicDimShape())
1862  return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1863  const Dimension srcDimRank = srcTp.getDimRank();
1864  if (srcDimRank != dimRank)
1865  return emitError(
1866  llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1867  "from the output tensor (rank={2}).",
1868  i, srcDimRank, dimRank));
1869  }
1870 
1871  for (Dimension d = 0; d < dimRank; d++) {
1872  const Size dstSh = dstTp.getDimShape()[d];
1873  if (d == concatDim) {
1874  if (!ShapedType::isDynamic(dstSh)) {
1875  // If we reach here, then all inputs have static shapes. So we
1876  // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1877  // to avoid redundant assertions in the loop.
1878  Size sumSz = 0;
1879  for (const auto src : getInputs())
1880  sumSz += getSparseTensorType(src).getDimShape()[d];
1881  // If all dimensions are statically known, the sum of all the input
1882  // dimensions should be equal to the output dimension.
1883  if (sumSz != dstSh)
1884  return emitError(
1885  "The concatenation dimension of the output tensor should be the "
1886  "sum of all the concatenation dimensions of the input tensors.");
1887  }
1888  } else {
1889  Size prev = dstSh;
1890  for (const auto src : getInputs()) {
1891  const auto sh = getSparseTensorType(src).getDimShape()[d];
1892  if (!ShapedType::isDynamic(prev) && sh != prev)
1893  return emitError("All dimensions (except for the concatenating one) "
1894  "should be equal.");
1895  prev = sh;
1896  }
1897  }
1898  }
1899 
1900  return success();
1901 }
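// For illustration of the shape checks above: concatenating 3x4 and 5x4
// inputs along dimension 0 requires an 8x4 (or dynamically sized) output,
// since sizes must sum along the concatenation dimension and match exactly
// on every other dimension.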
1902 
1903 void PushBackOp::build(OpBuilder &builder, OperationState &result,
1904  Value curSize, Value inBuffer, Value value) {
1905  build(builder, result, curSize, inBuffer, value, Value());
1906 }
1907 
1908 LogicalResult PushBackOp::verify() {
1909  if (Value n = getN()) {
1910  std::optional<int64_t> nValue = getConstantIntValue(n);
1911  if (nValue && nValue.value() < 1)
1912  return emitOpError("n must not be less than 1");
1913  }
1914  return success();
1915 }
1916 
1917 LogicalResult CompressOp::verify() {
1918  const auto stt = getSparseTensorType(getTensor());
1919  if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1920  return emitOpError("incorrect number of coordinates");
1921  return success();
1922 }
1923 
1924 void ForeachOp::build(
1925  OpBuilder &builder, OperationState &result, Value tensor,
1926  ValueRange initArgs, AffineMapAttr order,
1927  function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
1928  bodyBuilder) {
1929  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
1930  // Builds foreach body.
1931  if (!bodyBuilder)
1932  return;
1933  const auto stt = getSparseTensorType(tensor);
1934  const Dimension dimRank = stt.getDimRank();
1935 
1936  // Starts with `dimRank`-many coordinates.
1937  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
1938  // Followed by one value.
1939  blockArgTypes.push_back(stt.getElementType());
1940  // Followed by the reduction variables.
1941  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
1942 
1943  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());
1944 
1945  OpBuilder::InsertionGuard guard(builder);
1946  auto &region = *result.regions.front();
1947  Block *bodyBlock =
1948  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1949  bodyBuilder(builder, result.location,
1950  bodyBlock->getArguments().slice(0, dimRank),
1951  bodyBlock->getArguments()[dimRank],
1952  bodyBlock->getArguments().drop_front(dimRank + 1));
1953 }
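// For illustration: for a 2-d input tensor and one init arg, the body block
// created above receives (%i : index, %j : index, %v : elemTp, %acc), i.e.
// dimRank coordinates, then the current element, then the reduction
// arguments, matching the argument slicing passed to bodyBuilder above.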
1954 
1955 LogicalResult ForeachOp::verify() {
1956  const auto t = getSparseTensorType(getTensor());
1957  const Dimension dimRank = t.getDimRank();
1958  const auto args = getBody()->getArguments();
1959 
1960  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1961  return emitError("Level traverse order does not match tensor's level rank");
1962 
1963  if (dimRank + 1 + getInitArgs().size() != args.size())
1964  return emitError("Unmatched number of arguments in the block");
1965 
1966  if (getNumResults() != getInitArgs().size())
1967  return emitError("Mismatch in number of init arguments and results");
1968 
1969  if (getResultTypes() != getInitArgs().getTypes())
1970  return emitError("Mismatch in types of init arguments and results");
1971 
1972  // Cannot mark this const, because the getters aren't.
1973  auto yield = cast<YieldOp>(getBody()->getTerminator());
1974  if (yield.getNumOperands() != getNumResults() ||
1975  yield.getOperands().getTypes() != getResultTypes())
1976  return emitError("Mismatch in types of yield values and results");
1977 
1978  const auto iTp = IndexType::get(getContext());
1979  for (Dimension d = 0; d < dimRank; d++)
1980  if (args[d].getType() != iTp)
1981  return emitError(
1982  llvm::formatv("Expecting Index type for argument at index {0}", d));
1983 
1984  const auto elemTp = t.getElementType();
1985  const auto valueTp = args[dimRank].getType();
1986  if (elemTp != valueTp)
1987  return emitError(llvm::formatv("Unmatched element type between input tensor "
1988  "and block argument, expected: {0}, got: {1}",
1989  elemTp, valueTp));
1990  return success();
1991 }
1992 
1993 OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
1994  if (getSparseTensorEncoding(getInputCoo().getType()) ==
1995  getSparseTensorEncoding(getResultCoo().getType()))
1996  return getInputCoo();
1997 
1998  return {};
1999 }
2000 
2001 LogicalResult ReorderCOOOp::verify() {
2002  SparseTensorType srcStt = getSparseTensorType(getInputCoo());
2003  SparseTensorType dstStt = getSparseTensorType(getResultCoo());
2004 
2005  if (!srcStt.isCOOType() || !dstStt.isCOOType())
2006  return emitError("Expected COO sparse tensors only");
2007 
2008  if (!srcStt.hasSameDimToLvl(dstStt))
2009  return emitError("Unmatched dim2lvl map between input and result COO");
2010 
2011  if (srcStt.getPosType() != dstStt.getPosType() ||
2012  srcStt.getCrdType() != dstStt.getCrdType() ||
2013  srcStt.getElementType() != dstStt.getElementType())
2014  return emitError("Unmatched storage format between input and result COO");
2015 
2016  return success();
2017 }
2018 
2019 LogicalResult ReduceOp::verify() {
2020  Type inputType = getX().getType();
2021  Region &formula = getRegion();
2022  return verifyNumBlockArgs(this, formula, "reduce",
2023  TypeRange{inputType, inputType}, inputType);
2024 }
2025 
2026 LogicalResult SelectOp::verify() {
2027  Builder b(getContext());
2028  Type inputType = getX().getType();
2029  Type boolType = b.getI1Type();
2030  Region &formula = getRegion();
2031  return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
2032  boolType);
2033 }
2034 
2035 LogicalResult SortOp::verify() {
2036  AffineMap xPerm = getPermMap();
2037  uint64_t nx = xPerm.getNumDims();
2038  if (nx < 1)
2039  return emitError(llvm::formatv("Expected rank(perm_map) >= 1, got {0}", nx));
2040 
2041  if (!xPerm.isPermutation())
2042  return emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
2043 
2044  // We can't check the size of the buffers when n or buffer dimensions aren't
2045  // compile-time constants.
2046  std::optional<int64_t> cn = getConstantIntValue(getN());
2047  if (!cn)
2048  return success();
2049 
2050  // Verify dimensions.
2051  const auto checkDim = [&](Value v, Size minSize, const char *message) -> LogicalResult {
2052  const Size sh = getMemRefType(v).getShape()[0];
2053  if (!ShapedType::isDynamic(sh) && sh < minSize)
2054  return emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
2055  return success(); };
2056  uint64_t n = cn.value();
2057  uint64_t ny = 0;
2058  if (auto nyAttr = getNyAttr())
2059  ny = nyAttr.getInt();
2060  if (failed(checkDim(getXy(), n * (nx + ny),
2061  "Expected dimension(xy) >= n * (rank(perm_map) + ny)"))) return failure();
2062  for (Value opnd : getYs())
2063  if (failed(checkDim(opnd, n, "Expected dimension(y) >= n"))) return failure();
2064 
2065  return success();
2066 }
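// For illustration of the bound checked above: with n = 10, rank(perm_map)
// = 2, and ny = 1, the xy buffer must hold at least 10 * (2 + 1) = 30
// elements, and each y buffer at least 10.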
2067 
2068 //===----------------------------------------------------------------------===//
2069 // Sparse Tensor Iteration Operations.
2070 //===----------------------------------------------------------------------===//
2071 
2072 IterSpaceType IteratorType::getIterSpaceType() const {
2073  return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
2074  getHiLvl());
2075 }
2076 
2077 IteratorType IterSpaceType::getIteratorType() const {
2078  return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
2079 }
2080 
2081 /// Parses a level range in the form "$lo `to` $hi"
2082 /// or simply "$lo" if $hi - $lo = 1
2083 static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
2084  Level &lvlHi) {
2085  if (parser.parseInteger(lvlLo))
2086  return failure();
2087 
2088  if (succeeded(parser.parseOptionalKeyword("to"))) {
2089  if (parser.parseInteger(lvlHi))
2090  return failure();
2091  } else {
2092  lvlHi = lvlLo + 1;
2093  }
2094 
2095  if (lvlHi <= lvlLo)
2096  return parser.emitError(parser.getNameLoc(),
2097  "expect larger level upper bound than lower bound");
2098 
2099  return success();
2100 }
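// For illustration: "2 to 4" parses as lvlLo = 2, lvlHi = 4, while a bare
// "2" denotes the single-level range lvlLo = 2, lvlHi = 3.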
2101 
2102 /// Parses a level range in the form "$lo `to` $hi"
2103 /// or simply "$lo" if $hi - $lo = 1
2104 static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
2105  IntegerAttr &lvlHiAttr) {
2106  Level lvlLo, lvlHi;
2107  if (parseLevelRange(parser, lvlLo, lvlHi))
2108  return failure();
2109 
2110  lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
2111  lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
2112  return success();
2113 }
2114 
2115 /// Prints a level range in the form "$lo `to` $hi"
2116 /// or simply "$lo" if $hi - $lo = 1
2117 static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2118 
2119  if (lo + 1 == hi)
2120  p << lo;
2121  else
2122  p << lo << " to " << hi;
2123 }
2124 
2125 /// Prints a level range in the form "$lo `to` $hi"
2126 /// or simply "$lo" if $hi - $lo = 1
2127 static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
2128  IntegerAttr lvlHi) {
2129  unsigned lo = lvlLo.getValue().getZExtValue();
2130  unsigned hi = lvlHi.getValue().getZExtValue();
2131  printLevelRange(p, lo, hi);
2132 }
2133 
2134 static ParseResult
2135 parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
2136  SmallVectorImpl<OpAsmParser::Argument> &iterators,
2137  SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
2138  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
2139  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
2140 
2141  // Parse "%iters, ... in %spaces, ..."
2141  // Parse "%iters, ... in %spaces, ..."
2142  if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
2143  parser.parseOperandList(spaces))
2144  return failure();
2145 
2146  if (iterators.size() != spaces.size())
2147  return parser.emitError(
2148  parser.getNameLoc(),
2149  "mismatch in number of sparse iterators and sparse spaces");
2150 
2151  // Parse "at(%crd0, _, ...)"
2152  LevelSet crdUsedLvlSet;
2153  bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at"));
2154  unsigned lvlCrdCnt = 0;
2155  if (hasUsedCrds) {
2156  ParseResult crdList = parser.parseCommaSeparatedList(
2157  OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
2158  if (parser.parseOptionalKeyword("_")) {
2159  if (parser.parseArgument(iterArgs.emplace_back()))
2160  return failure();
2161  // Always use IndexType for the coordinate.
2162  crdUsedLvlSet.set(lvlCrdCnt);
2163  iterArgs.back().type = parser.getBuilder().getIndexType();
2164  }
2165  lvlCrdCnt += 1;
2166  return success();
2167  });
2168  if (failed(crdList)) {
2169  return parser.emitError(
2170  parser.getNameLoc(),
2171  "expecting SSA value or \"_\" for level coordinates");
2172  }
2173  }
2174  // Set the CrdUsedLvl bitset.
2175  state.addAttribute("crdUsedLvls",
2176  parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
2177 
2178  // Parse "iter_args(%arg = %init, ...)"
2179  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2180  if (hasIterArgs)
2181  if (parser.parseAssignmentList(iterArgs, initArgs))
2182  return failure();
2183 
2184  SmallVector<Type> iterSpaceTps;
2185  // parse ": sparse_tensor.iter_space -> ret"
2186  if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
2187  return failure();
2188  if (iterSpaceTps.size() != spaces.size())
2189  return parser.emitError(parser.getNameLoc(),
2190  "mismatch in number of iteration space operands "
2191  "and iteration space types");
2192 
2193  for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
2194  IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
2195  if (!spaceTp)
2196  return parser.emitError(parser.getNameLoc(),
2197  "expected sparse_tensor.iter_space type for "
2198  "iteration space operands");
2199  if (hasUsedCrds && spaceTp.getSpaceDim() != lvlCrdCnt)
2200  return parser.emitError(parser.getNameLoc(),
2201  "mismatch in number of iteration space dimension "
2202  "and specified coordinates");
2203  it.type = spaceTp.getIteratorType();
2204  }
2205 
2206  if (hasIterArgs)
2207  if (parser.parseArrowTypeList(state.types))
2208  return failure();
2209 
2210  // Resolves input operands.
2211  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2212  state.operands))
2213  return failure();
2214 
2215  if (hasIterArgs) {
2216  unsigned numCrds = crdUsedLvlSet.count();
2217  // Strip off the leading args that are used for coordinates.
2218  MutableArrayRef args = MutableArrayRef(iterArgs).drop_front(numCrds);
2219  if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2220  return parser.emitError(
2221  parser.getNameLoc(),
2222  "mismatch in number of iteration arguments and return values");
2223  }
2224 
2225  for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2226  it.type = tp;
2227  if (parser.resolveOperand(init, tp, state.operands))
2228  return failure();
2229  }
2230  }
2231  return success();
2232 }
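// For illustration, the loop header syntax handled above (a loose sketch
// with hypothetical names):
//   %r = sparse_tensor.iterate %it in %space at(%crd0, _)
//            iter_args(%acc = %init)
//        : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2> -> index
// where "_" marks an unused level coordinate and iter_args lists the
// loop-carried values.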
2233 
2234 LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2235  MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
2236  DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
2237  SmallVectorImpl<mlir::Type> &ret) {
2238 
2239  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2240  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
2241  ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
2242  adaptor.getHiLvl()));
2243  return success();
2244 }
2245 
2246 LogicalResult ExtractIterSpaceOp::verify() {
2247  if (getLoLvl() >= getHiLvl())
2248  return emitOpError("expected level lower bound to be smaller than upper bound");
2249 
2250  TypedValue<IteratorType> pIter = getParentIter();
2251  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2252  return emitOpError(
2253  "parent iterator should be specified iff level lower bound equals 0");
2254  }
2255 
2256  if (pIter) {
2257  IterSpaceType spaceTp = getExtractedSpace().getType();
2258  if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2259  return emitOpError(
2260  "mismatch in parent iterator encoding and iteration space encoding.");
2261 
2262  if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2263  return emitOpError("parent iterator should be used to extract an "
2264  "iteration space from a consecutive level.");
2265  }
2266 
2267  return success();
2268 }
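// For illustration: extracting "lvls = 2 to 4" requires a parent iterator
// whose iteration space ends exactly at level 2, whereas an extraction
// starting at level 0 must omit the parent iterator, per the iff check above.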
2269 
2270 struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
2271  using OpRewritePattern::OpRewritePattern;
2272 
2273  LogicalResult matchAndRewrite(IterateOp iterateOp,
2274  PatternRewriter &rewriter) const override {
2275  LevelSet newUsedLvls(0);
2276  llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
2277  for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
2278  if (auto crd = iterateOp.getLvlCrd(i)) {
2279  if (crd->getUsers().empty())
2280  toRemove.set(crd->getArgNumber());
2281  else
2282  newUsedLvls.set(i);
2283  }
2284  }
2285 
2286  // All coordinates are used.
2287  if (toRemove.none())
2288  return failure();
2289 
2290  rewriter.startOpModification(iterateOp);
2291  iterateOp.setCrdUsedLvls(newUsedLvls);
2292  iterateOp.getBody()->eraseArguments(toRemove);
2293  rewriter.finalizeOpModification(iterateOp);
2294  return success();
2295  }
2296 };
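// For illustration: a loop declared "at(%i, %j)" in which %j has no users is
// rewritten by the pattern above into "at(%i, _)", erasing the dead block
// argument and clearing the corresponding bit in crdUsedLvls.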
2297 
2298 void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
2299  mlir::MLIRContext *context) {
2300  results.add<RemoveUnusedLvlCrds>(context);
2301 }
2302 
2303 void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2304  Value iterSpace, ValueRange initArgs) {
2305  unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
2306  // All ones.
2307  LevelSet set((1ULL << rank) - 1);
2308  return build(builder, odsState, iterSpace, initArgs, set);
2309 }
2310 
2311 void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2312  Value iterSpace, ValueRange initArgs,
2313  LevelSet crdUsedLvls) {
2314  OpBuilder::InsertionGuard guard(builder);
2315 
2316  odsState.addOperands(iterSpace);
2317  odsState.addOperands(initArgs);
2318  odsState.getOrAddProperties<Properties>().crdUsedLvls =
2319  builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
2320  Region *bodyRegion = odsState.addRegion();
2321  odsState.addTypes(initArgs.getTypes());
2322  Block *bodyBlock = builder.createBlock(bodyRegion);
2323 
2324  // First argument, sparse iterator
2325  bodyBlock->addArgument(
2326  llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
2327  odsState.location);
2328 
2329  // Followed by a list of used coordinates.
2330  for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
2331  bodyBlock->addArgument(builder.getIndexType(), odsState.location);
2332 
2333  // Followed by a list of user-provided loop arguments.
2334  for (Value v : initArgs)
2335  bodyBlock->addArgument(v.getType(), v.getLoc());
2336 }
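// For illustration: with crdUsedLvls covering levels {0, 1} and two init
// args, the block created above has arguments (iterator, %crd0 : index,
// %crd1 : index, %init0, %init1), in exactly that order.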
2337 
2338 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
2339  OpAsmParser::Argument iterator;
2340  OpAsmParser::UnresolvedOperand iterSpace;
2341 
2342  SmallVector<OpAsmParser::Argument> iters, iterArgs;
2343  if (parseSparseSpaceLoop(parser, result, iters, iterArgs))
2344  return failure();
2345  if (iters.size() != 1)
2346  return parser.emitError(parser.getNameLoc(),
2347  "expected only one iterator/iteration space");
2348 
2349  iters.append(iterArgs);
2350  Region *body = result.addRegion();
2351  if (parser.parseRegion(*body, iters))
2352  return failure();
2353 
2354  IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2355 
2356  // Parse the optional attribute list.
2357  if (parser.parseOptionalAttrDict(result.attributes))
2358  return failure();
2359 
2360  return success();
2361 }
2362 
2363 /// Prints the initialization list in the form of
2364 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
2365 /// where 'inner' values are assumed to be region arguments and 'outer' values
2366 /// are regular SSA values.
2367 static void printInitializationList(OpAsmPrinter &p,
2368  Block::BlockArgListType blocksArgs,
2369  ValueRange initializers,
2370  StringRef prefix = "") {
2371  assert(blocksArgs.size() == initializers.size() &&
2372  "expected same length of arguments and initializers");
2373  if (initializers.empty())
2374  return;
2375 
2376  p << prefix << '(';
2377  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
2378  p << std::get<0>(it) << " = " << std::get<1>(it);
2379  });
2380  p << ")";
2381 }
2382 
2383 static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim,
2384  Block::BlockArgListType blocksArgs,
2385  LevelSet crdUsedLvls) {
2386  if (crdUsedLvls.empty())
2387  return;
2388 
2389  p << " at(";
2390  for (unsigned i = 0; i < spaceDim; i++) {
2391  if (crdUsedLvls[i]) {
2392  p << blocksArgs.front();
2393  blocksArgs = blocksArgs.drop_front();
2394  } else {
2395  p << "_";
2396  }
2397  if (i != spaceDim - 1)
2398  p << ", ";
2399  }
2400  assert(blocksArgs.empty());
2401  p << ")";
2402 }
2403 
2404 void IterateOp::print(OpAsmPrinter &p) {
2405  p << " " << getIterator() << " in " << getIterSpace();
2406  printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
2407  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
2408 
2409  p << " : " << getIterSpace().getType() << " ";
2410  if (!getInitArgs().empty())
2411  p << "-> (" << getInitArgs().getTypes() << ") ";
2412 
2413  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2414  /*printBlockTerminators=*/!getInitArgs().empty());
2415 }
2416 
2417 LogicalResult IterateOp::verify() {
2418  if (getInitArgs().size() != getNumResults()) {
2419  return emitOpError(
2420  "mismatch in number of loop-carried values and defined values");
2421  }
2422  if (getCrdUsedLvls().max() > getSpaceDim())
2423  return emitOpError("requires out-of-bound coordinates");
2424 
2425  return success();
2426 }
2427 
2428 LogicalResult IterateOp::verifyRegions() {
2429  if (getIterator().getType() != getIterSpace().getType().getIteratorType())
2430  return emitOpError("mismatch in iterator and iteration space type");
2431  if (getNumRegionIterArgs() != getNumResults())
2432  return emitOpError(
2433  "mismatch in number of basic block args and defined values");
2434 
2435  auto initArgs = getInitArgs();
2436  auto iterArgs = getRegionIterArgs();
2437  auto yieldVals = getYieldedValues();
2438  auto opResults = getResults();
2439  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2440  opResults.size()})) {
2441  return emitOpError() << "number mismatch between iter args and results.";
2442  }
2443 
2444  for (auto [i, init, iter, yield, ret] :
2445  llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2446  if (init.getType() != ret.getType())
2447  return emitOpError() << "types mismatch between " << i
2448  << "th iter operand and defined value";
2449  if (iter.getType() != ret.getType())
2450  return emitOpError() << "types mismatch between " << i
2451  << "th iter region arg and defined value";
2452  if (yield.getType() != ret.getType())
2453  return emitOpError() << "types mismatch between " << i
2454  << "th yield value and defined value";
2455  }
2456 
2457  return success();
2458 }
2459 
2460 /// OpInterfaces' methods implemented by IterateOp.
2461 SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
2462 
2463 MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
2464  return getInitArgsMutable();
2465 }
2466 
2467 Block::BlockArgListType IterateOp::getRegionIterArgs() {
2468  return getRegion().getArguments().take_back(getNumRegionIterArgs());
2469 }
2470 
2471 std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
2472  return cast<sparse_tensor::YieldOp>(
2473  getRegion().getBlocks().front().getTerminator())
2474  .getResultsMutable();
2475 }
2476 
2477 std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
2478 
2479 OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2480  return getInitArgs();
2481 }
2482 
2483 void IterateOp::getSuccessorRegions(RegionBranchPoint point,
2484  SmallVectorImpl<RegionSuccessor> &regions) {
2485  // Both the operation itself and the region may be branching into the body or
2486  // back into the operation itself.
2487  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2488  // It is possible for the loop not to enter the body.
2489  regions.push_back(RegionSuccessor(getResults()));
2490 }
2491 
2492 //===----------------------------------------------------------------------===//
2493 // Sparse Tensor Dialect Setups.
2494 //===----------------------------------------------------------------------===//
2495 
2496 /// Materialize a single constant operation from a given attribute value with
2497 /// the desired resultant type.
2498 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2499  Attribute value, Type type,
2500  Location loc) {
2501  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2502  return op;
2503  return nullptr;
2504 }
2505 
2506 namespace {
2507 struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
2508  using OpAsmDialectInterface::OpAsmDialectInterface;
2509 
2510  AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
2511  if (isa<SparseTensorEncodingAttr>(attr)) {
2512  os << "sparse";
2513  return AliasResult::OverridableAlias;
2514  }
2515  return AliasResult::NoAlias;
2516  }
2517 };
2518 } // namespace
2519 
2520 void SparseTensorDialect::initialize() {
2521  addInterface<SparseTensorAsmDialectInterface>();
2522  addAttributes<
2523 #define GET_ATTRDEF_LIST
2524 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
2525  >();
2526  addTypes<
2527 #define GET_TYPEDEF_LIST
2528 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
2529  >();
2530  addOperations<
2531 #define GET_OP_LIST
2532 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2533  >();
2534  declarePromisedInterfaces<
2535  bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
2536  NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
2537  ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
2538 }
2539 
2540 #define GET_OP_CLASSES
2541 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2542 
2543 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
Definition: Enums.h:299