1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "Detail/DimLvlMapParser.h"
12 
17 
22 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "llvm/ADT/Bitset.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 #include "llvm/Support/FormatVariadic.h"
30 
31 #define GET_ATTRDEF_CLASSES
32 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
33 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
34 
35 // Forward declarations: the following custom print/parsing methods are
36 // referenced by the generated code for SparseTensorTypes.td.
37 static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
38                                          mlir::sparse_tensor::Level &,
39                                          mlir::sparse_tensor::Level &);
40 static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
41                             mlir::sparse_tensor::Level);
42 
43 #define GET_TYPEDEF_CLASSES
44 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
45 
46 using namespace mlir;
47 using namespace mlir::sparse_tensor;
48 
49 // Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
50 // well.
51 namespace mlir::sparse_tensor {
52 llvm::hash_code hash_value(LevelType lt) {
53  return llvm::hash_value(static_cast<uint64_t>(lt));
54 }
55 } // namespace mlir::sparse_tensor
56 
57 //===----------------------------------------------------------------------===//
58 // Local Convenience Methods.
59 //===----------------------------------------------------------------------===//
60 
61 static constexpr bool acceptBitWidth(unsigned bitWidth) {
62  switch (bitWidth) {
63  case 0:
64  case 8:
65  case 16:
66  case 32:
67  case 64:
68  return true;
69  default:
70  return false;
71  }
72 }
73 
74 static SmallVector<Size>
75 getSparseFieldShape(const SparseTensorEncodingAttr enc,
76  std::optional<ArrayRef<int64_t>> dimShape) {
77  assert(enc);
78   // With only the encoding, we cannot determine the static shape for leading
79   // batch levels; we therefore return a dynamic shape memref instead.
80  SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
81  if (dimShape.has_value()) {
82  // If the actual tensor shape is provided, we can then refine the leading
83  // batch dimension.
84  SmallVector<int64_t> lvlShape =
85  enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
86  memrefShape.assign(lvlShape.begin(),
87  lvlShape.begin() + enc.getBatchLvlRank());
88  }
89  // Another dynamic dimension to store the sparse level.
90  memrefShape.push_back(ShapedType::kDynamic);
91  return memrefShape;
92 }
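// Illustrative sketch (added commentary, not from the original source): for a
// hypothetical encoding with one leading batch level and a provided dim shape
// of [4, 100], the code above refines the batch extent and returns [4, ?],
// i.e. a memref<4 x ? x ...> field shape; without a dim shape it returns
// [?, ?].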
93 
94 //===----------------------------------------------------------------------===//
95 // SparseTensorDialect StorageLayout.
96 //===----------------------------------------------------------------------===//
97 
98 static constexpr Level kInvalidLevel = -1u;
99 static constexpr Level kInvalidFieldIndex = -1u;
100 static constexpr FieldIndex kDataFieldStartingIdx = 0;
101 
102 void StorageLayout::foreachField(
103     llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
104                             LevelType)>
105  callback) const {
106  const auto lvlTypes = enc.getLvlTypes();
107  const Level lvlRank = enc.getLvlRank();
108  SmallVector<COOSegment> cooSegs = enc.getCOOSegments();
109   FieldIndex fieldIdx = kDataFieldStartingIdx;
110 
111  ArrayRef cooSegsRef = cooSegs;
112  // Per-level storage.
113  for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
114  const auto lt = lvlTypes[l];
115  if (isWithPosLT(lt)) {
116  if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
117  return;
118  }
119  if (isWithCrdLT(lt)) {
120  if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
121  return;
122  }
123  if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
124  if (!cooSegsRef.front().isSoA) {
125       // AoS COO, all singletons are fused into one memref. Skip the entire
126       // COO segment.
127  l = cooSegsRef.front().lvlRange.second;
128  } else {
129  // SoA COO, each singleton level has one memref.
130  l++;
131  }
132  // Expire handled COO segment.
133  cooSegsRef = cooSegsRef.drop_front();
134  } else {
135  // Non COO levels.
136  l++;
137  }
138  }
139  // The values array.
140   if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
141                  LevelFormat::Undef)))
142     return;
143   // Put metadata at the end.
144   if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
145                  LevelFormat::Undef)))
146     return;
147 }
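// Illustrative sketch (added commentary, not from the original source): for a
// CSR-like encoding `(d0, d1) -> (d0 : dense, d1 : compressed)`, the callback
// above is invoked in this order:
//   fieldIdx 0: PosMemRef for level 1
//   fieldIdx 1: CrdMemRef for level 1
//   fieldIdx 2: ValMemRef
//   fieldIdx 3: StorageSpec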
148 
149 void sparse_tensor::foreachFieldAndTypeInSparseTensor(
150     SparseTensorType stt,
151     llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
152                             LevelType)>
153         callback) {
154   assert(stt.hasEncoding());
155 
156   SmallVector<int64_t> memrefShape =
157       getSparseFieldShape(stt.getEncoding(), stt.getDimShape());
158 
159  const Type specType = StorageSpecifierType::get(stt.getEncoding());
160  // memref<[batch] x ? x pos> positions
161  const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
162  // memref<[batch] x ? x crd> coordinates
163  const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
164  // memref<[batch] x ? x eltType> values
165  const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());
166 
167  StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
168  callback](FieldIndex fieldIdx,
169  SparseTensorFieldKind fieldKind,
170  Level lvl, LevelType lt) -> bool {
171     switch (fieldKind) {
172     case SparseTensorFieldKind::StorageSpec:
173       return callback(specType, fieldIdx, fieldKind, lvl, lt);
174     case SparseTensorFieldKind::PosMemRef:
175       return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
176     case SparseTensorFieldKind::CrdMemRef:
177       return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
178     case SparseTensorFieldKind::ValMemRef:
179       return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
180     };
181  llvm_unreachable("unrecognized field kind");
182  });
183 }
184 
185 unsigned StorageLayout::getNumFields() const {
186  unsigned numFields = 0;
187   foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
188                             LevelType) -> bool {
189  numFields++;
190  return true;
191  });
192  return numFields;
193 }
194 
195 unsigned StorageLayout::getNumDataFields() const {
196   unsigned numFields = 0; // one value memref
197   foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
198                             LevelType) -> bool {
199  if (fidx >= kDataFieldStartingIdx)
200  numFields++;
201  return true;
202  });
203  numFields -= 1; // the last field is StorageSpecifier
204  assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
205  return numFields;
206 }
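// Illustrative sketch (added commentary, not from the original source): for
// the CSR-like layout sketched above, getNumFields() returns 4 (positions,
// coordinates, values, specifier) and getNumDataFields() returns 3, since the
// trailing storage specifier is not counted as a data field.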
207 
208 std::pair<FieldIndex, unsigned>
209 StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
210                                       std::optional<Level> lvl) const {
211  FieldIndex fieldIdx = kInvalidFieldIndex;
212  unsigned stride = 1;
213   if (kind == SparseTensorFieldKind::CrdMemRef) {
214     assert(lvl.has_value());
215  const Level cooStart = enc.getAoSCOOStart();
216  const Level lvlRank = enc.getLvlRank();
217  if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
218  lvl = cooStart;
219  stride = lvlRank - cooStart;
220  }
221  }
222  foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
223  SparseTensorFieldKind fKind, Level fLvl,
224  LevelType lt) -> bool {
225  if ((lvl && fLvl == lvl.value() && kind == fKind) ||
226  (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
227  fieldIdx = fIdx;
228  // Returns false to break the iteration.
229  return false;
230  }
231  return true;
232  });
233  assert(fieldIdx != kInvalidFieldIndex);
234  return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // SparseTensorDialect Attribute Methods.
239 //===----------------------------------------------------------------------===//
240 
241 std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
242  return isDynamic(v) ? std::nullopt
243  : std::make_optional(static_cast<uint64_t>(v));
244 }
245 
246 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
247  return getStatic(getOffset());
248 }
249 
250 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
251  return getStatic(getStride());
252 }
253 
254 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
255  return getStatic(getSize());
256 }
257 
258 bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
259  return isDynamic(getOffset()) && isDynamic(getStride()) &&
260  isDynamic(getSize());
261 }
262 
263 std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
264  return isDynamic(v) ? "?" : std::to_string(v);
265 }
266 
267 void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
268  assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
269  os << '(';
270  os << getStaticString(getOffset());
271  os << ", ";
272  os << getStaticString(getSize());
273  os << ", ";
274  os << getStaticString(getStride());
275  os << ')';
276 }
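// Illustrative sketch (added commentary, not from the original source): a
// slice with static offset 1, size 4, and stride 2 prints as "(1, 4, 2)",
// while a fully dynamic slice prints as "(?, ?, ?)".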
277 
278 void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
279  print(printer.getStream());
280 }
281 
282 static ParseResult parseOptionalStaticSlice(int64_t &result,
283  AsmParser &parser) {
284  auto parseResult = parser.parseOptionalInteger(result);
285  if (parseResult.has_value()) {
286  if (parseResult.value().succeeded() && result < 0) {
287  parser.emitError(
288  parser.getCurrentLocation(),
289  "expect positive value or ? for slice offset/size/stride");
290  return failure();
291  }
292  return parseResult.value();
293  }
294 
295   // Else, a '?' which represents a dynamic slice.
296  result = SparseTensorDimSliceAttr::kDynamic;
297  return parser.parseQuestion();
298 }
299 
300 Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
301   int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
302 
303  if (failed(parser.parseLParen()) ||
304  failed(parseOptionalStaticSlice(offset, parser)) ||
305  failed(parser.parseComma()) ||
306  failed(parseOptionalStaticSlice(size, parser)) ||
307  failed(parser.parseComma()) ||
308  failed(parseOptionalStaticSlice(stride, parser)) ||
309  failed(parser.parseRParen()))
310  return {};
311 
312  return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
313  offset, size, stride);
314 }
315 
316 LogicalResult
317 SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
318                                  int64_t offset, int64_t size, int64_t stride) {
319  if (!isDynamic(offset) && offset < 0)
320  return emitError() << "expect non-negative value or ? for slice offset";
321  if (!isDynamic(size) && size <= 0)
322  return emitError() << "expect positive value or ? for slice size";
323  if (!isDynamic(stride) && stride <= 0)
324  return emitError() << "expect positive value or ? for slice stride";
325  return success();
326 }
327 
328 SparseTensorEncodingAttr
329 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
330  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
331   return SparseTensorEncodingAttr::get(
332       getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
333  getCrdWidth(), getExplicitVal(), getImplicitVal());
334 }
335 
336 SparseTensorEncodingAttr
337 SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
338  return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
339 }
340 
341 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
342  return withDimToLvl(AffineMap());
343 }
344 
345 SparseTensorEncodingAttr
346 SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
347  unsigned crdWidth) const {
348  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
349   return SparseTensorEncodingAttr::get(
350       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
351  crdWidth, getExplicitVal(), getImplicitVal());
352 }
353 
354 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
355  return withBitWidths(0, 0);
356 }
357 
358 SparseTensorEncodingAttr
359 SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
360  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
361   return SparseTensorEncodingAttr::get(
362       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
363  getCrdWidth(), explicitVal, getImplicitVal());
364 }
365 
366 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
367  return withExplicitVal(Attribute());
368 }
369 
370 SparseTensorEncodingAttr
371 SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
372  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
373   return SparseTensorEncodingAttr::get(
374       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
375  getCrdWidth(), getExplicitVal(), implicitVal);
376 }
377 
378 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
379  return withImplicitVal(Attribute());
380 }
381 
382 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
383  ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
384   return SparseTensorEncodingAttr::get(
385       getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
386  getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
387 }
388 
389 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
390  return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
391 }
392 
393 uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
394  ArrayRef<LevelType> lvlTypes = getLvlTypes();
395  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
396  return std::distance(lastBatch, lvlTypes.rend());
397 }
398 
399 bool SparseTensorEncodingAttr::isAllDense() const {
400   return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
401 }
402 
403 bool SparseTensorEncodingAttr::isAllOrdered() const {
404  return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
405 }
406 
407 Type SparseTensorEncodingAttr::getCrdElemType() const {
408  if (!getImpl())
409  return nullptr;
410  if (getCrdWidth())
411  return IntegerType::get(getContext(), getCrdWidth());
412  return IndexType::get(getContext());
413 }
414 
415 Type SparseTensorEncodingAttr::getPosElemType() const {
416  if (!getImpl())
417  return nullptr;
418  if (getPosWidth())
419  return IntegerType::get(getContext(), getPosWidth());
420  return IndexType::get(getContext());
421 }
422 
423 MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
424  std::optional<ArrayRef<int64_t>> dimShape) const {
425  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
426  return MemRefType::get(shape, getCrdElemType());
427 }
428 
429 MemRefType SparseTensorEncodingAttr::getPosMemRefType(
430  std::optional<ArrayRef<int64_t>> dimShape) const {
431  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
432  return MemRefType::get(shape, getPosElemType());
433 }
434 
435 bool SparseTensorEncodingAttr::isIdentity() const {
436  return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
437 }
438 
439 bool SparseTensorEncodingAttr::isPermutation() const {
440   return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
441 }
442 
443 Dimension SparseTensorEncodingAttr::getDimRank() const {
444  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
445  const auto dimToLvl = getDimToLvl();
446  return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
447 }
448 
449 Level SparseTensorEncodingAttr::getLvlRank() const {
450  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
451  return getLvlTypes().size();
452 }
453 
454 LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
455  if (!getImpl())
456  return LevelFormat::Batch;
457  assert(l < getLvlRank() && "Level is out of bounds");
458  return getLvlTypes()[l];
459 }
460 
461 bool SparseTensorEncodingAttr::isSlice() const {
462  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
463  return !getDimSlices().empty();
464 }
465 
466 SparseTensorDimSliceAttr
467 SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
468  assert(isSlice() && "Is not a slice");
469  const auto dimSlices = getDimSlices();
470  assert(dim < dimSlices.size() && "Dimension is out of bounds");
471  return dimSlices[dim];
472 }
473 
474 std::optional<uint64_t>
475 SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
476  return getDimSlice(dim).getStaticOffset();
477 }
478 
479 std::optional<uint64_t>
480 SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
481  return getDimSlice(dim).getStaticStride();
482 }
483 
484 std::optional<uint64_t>
485 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
486  return getStaticDimSliceOffset(toDim(*this, lvl));
487 }
488 
489 std::optional<uint64_t>
490 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
491  return getStaticDimSliceStride(toDim(*this, lvl));
492 }
493 
494 SmallVector<int64_t>
495 SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
496  CrdTransDirectionKind dir) const {
497  if (isIdentity())
498  return SmallVector<int64_t>(srcShape);
499 
500   SmallVector<int64_t> ret;
501   unsigned rank =
502  dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
503  ret.reserve(rank);
504 
505  if (isPermutation()) {
506  for (unsigned r = 0; r < rank; r++) {
507  unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
508  : toLvl(*this, r);
509  ret.push_back(srcShape[trans]);
510  }
511  return ret;
512  }
513 
514  // Handle non-permutation maps.
515  AffineMap transMap =
516  dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
517 
518   SmallVector<AffineExpr> dimRep;
519   dimRep.reserve(srcShape.size());
520  for (int64_t sz : srcShape) {
521  if (!ShapedType::isDynamic(sz)) {
522  // Push back the max coordinate for the given dimension/level size.
523  dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
524  } else {
525       // A dynamic size, use an AffineDimExpr to symbolize the value.
526  dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
527  }
528  };
529 
530  for (AffineExpr exp : transMap.getResults()) {
531  // Do constant propagation on the affine map.
532  AffineExpr evalExp =
533  simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
534  // use llvm namespace here to avoid ambiguity
535  if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
536  ret.push_back(c.getValue() + 1);
537  } else {
538  if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
539  mod && mod.getKind() == AffineExprKind::Mod) {
540  // We can still infer a static bound for expressions in form
541  // "d % constant" since d % constant \in [0, constant).
542  if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
543  ret.push_back(bound.getValue());
544  continue;
545  }
546  }
547  ret.push_back(ShapedType::kDynamic);
548  }
549  }
550  assert(ret.size() == rank);
551  return ret;
552 }
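// Illustrative sketch (added commentary, not from the original source): for a
// 2x3 block-sparse map `(d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2,
// d1 mod 3)`, translating the dim shape [6, 12] in the dim2lvl direction
// yields the lvl shape [3, 4, 2, 3]. With a dynamic dim shape [?, ?], the
// floordiv levels become dynamic but the mod levels still get the static
// bounds [?, ?, 2, 3], as handled in the loop above.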
553 
554 ValueRange
555 SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
556  ValueRange crds,
557  CrdTransDirectionKind dir) const {
558  if (!getImpl())
559  return crds;
560 
561  SmallVector<Type> retType(
562  dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
563  builder.getIndexType());
564  auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
565  return transOp.getOutCrds();
566 }
567 
568 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
569   // Open "<{" part.
570  if (failed(parser.parseLess()))
571  return {};
572  if (failed(parser.parseLBrace()))
573  return {};
574 
575  // Process the data from the parsed dictionary value into struct-like data.
576  SmallVector<LevelType> lvlTypes;
577   SmallVector<SparseTensorDimSliceAttr> dimSlices;
578   AffineMap dimToLvl = {};
579  AffineMap lvlToDim = {};
580  unsigned posWidth = 0;
581  unsigned crdWidth = 0;
582  Attribute explicitVal;
583  Attribute implicitVal;
584  StringRef attrName;
585  SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
586  "explicitVal", "implicitVal"};
587  while (succeeded(parser.parseOptionalKeyword(&attrName))) {
588  // Detect admissible keyword.
589  auto *it = find(keys, attrName);
590  if (it == keys.end()) {
591  parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
592  return {};
593  }
594  unsigned keyWordIndex = it - keys.begin();
595  // Consume the `=` after keys
596  if (failed(parser.parseEqual()))
597  return {};
598  // Dispatch on keyword.
599  switch (keyWordIndex) {
600  case 0: { // map
601  ir_detail::DimLvlMapParser cParser(parser);
602  auto res = cParser.parseDimLvlMap();
603  if (failed(res))
604  return {};
605  const auto &dlm = *res;
606 
607  const Level lvlRank = dlm.getLvlRank();
608  for (Level lvl = 0; lvl < lvlRank; lvl++)
609  lvlTypes.push_back(dlm.getLvlType(lvl));
610 
611  const Dimension dimRank = dlm.getDimRank();
612  for (Dimension dim = 0; dim < dimRank; dim++)
613  dimSlices.push_back(dlm.getDimSlice(dim));
614  // NOTE: the old syntax requires an all-or-nothing approach to
615  // `dimSlices`; therefore, if any slice actually exists then we need
616  // to convert null-DSA into default/nop DSA.
617  const auto isDefined = [](SparseTensorDimSliceAttr slice) {
618  return static_cast<bool>(slice.getImpl());
619  };
620  if (llvm::any_of(dimSlices, isDefined)) {
621  const auto defaultSlice =
622             SparseTensorDimSliceAttr::get(parser.getContext());
623         for (Dimension dim = 0; dim < dimRank; dim++)
624  if (!isDefined(dimSlices[dim]))
625  dimSlices[dim] = defaultSlice;
626  } else {
627  dimSlices.clear();
628  }
629 
630  dimToLvl = dlm.getDimToLvlMap(parser.getContext());
631  lvlToDim = dlm.getLvlToDimMap(parser.getContext());
632  break;
633  }
634  case 1: { // posWidth
635  Attribute attr;
636  if (failed(parser.parseAttribute(attr)))
637  return {};
638  auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
639  if (!intAttr) {
640  parser.emitError(parser.getNameLoc(),
641  "expected an integral position bitwidth");
642  return {};
643  }
644  posWidth = intAttr.getInt();
645  break;
646  }
647  case 2: { // crdWidth
648  Attribute attr;
649  if (failed(parser.parseAttribute(attr)))
650  return {};
651  auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
652  if (!intAttr) {
653  parser.emitError(parser.getNameLoc(),
654  "expected an integral index bitwidth");
655  return {};
656  }
657  crdWidth = intAttr.getInt();
658  break;
659  }
660  case 3: { // explicitVal
661  Attribute attr;
662  if (failed(parser.parseAttribute(attr)))
663  return {};
664  if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
665  explicitVal = result;
666  } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
667  explicitVal = result;
668  } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
669  explicitVal = result;
670  } else {
671  parser.emitError(parser.getNameLoc(),
672  "expected a numeric value for explicitVal");
673  return {};
674  }
675  break;
676  }
677  case 4: { // implicitVal
678  Attribute attr;
679  if (failed(parser.parseAttribute(attr)))
680  return {};
681  if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
682  implicitVal = result;
683  } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
684  implicitVal = result;
685  } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
686  implicitVal = result;
687  } else {
688  parser.emitError(parser.getNameLoc(),
689  "expected a numeric value for implicitVal");
690  return {};
691  }
692  break;
693  }
694  } // switch
695  // Only last item can omit the comma.
696  if (parser.parseOptionalComma().failed())
697  break;
698  }
699 
700  // Close "}>" part.
701  if (failed(parser.parseRBrace()))
702  return {};
703  if (failed(parser.parseGreater()))
704  return {};
705 
706  // Construct struct-like storage for attribute.
707  if (!lvlToDim || lvlToDim.isEmpty()) {
708  lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
709  }
710  return parser.getChecked<SparseTensorEncodingAttr>(
711  parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
712  explicitVal, implicitVal, dimSlices);
713 }
714 
715 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
716  auto map = static_cast<AffineMap>(getDimToLvl());
717  // Empty affine map indicates identity map
718  if (!map)
719  map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
720  printer << "<{ map = ";
721  printSymbols(map, printer);
722  printer << '(';
723  printDimensions(map, printer, getDimSlices());
724  printer << ") -> (";
725  printLevels(map, printer, getLvlTypes());
726  printer << ')';
727  // Print remaining members only for non-default values.
728  if (getPosWidth())
729  printer << ", posWidth = " << getPosWidth();
730  if (getCrdWidth())
731  printer << ", crdWidth = " << getCrdWidth();
732  if (getExplicitVal()) {
733  printer << ", explicitVal = " << getExplicitVal();
734  }
735  if (getImplicitVal())
736  printer << ", implicitVal = " << getImplicitVal();
737  printer << " }>";
738 }
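// Illustrative sketch (added commentary, not from the original source): with
// the parse/print methods above, a CSR encoding with 32-bit overhead types
// round-trips as:
//   #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed),
//     posWidth = 32,
//     crdWidth = 32
//   }>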
739 
740 void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
741  AsmPrinter &printer) const {
742  if (map.getNumSymbols() == 0)
743  return;
744  printer << '[';
745  for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
746  printer << 's' << i << ", ";
747  if (map.getNumSymbols() >= 1)
748  printer << 's' << map.getNumSymbols() - 1;
749  printer << ']';
750 }
751 
752 void SparseTensorEncodingAttr::printDimensions(
753  AffineMap &map, AsmPrinter &printer,
754  ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
755  if (!dimSlices.empty()) {
756  for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
757  printer << 'd' << i << " : " << dimSlices[i] << ", ";
758  if (map.getNumDims() >= 1) {
759  printer << 'd' << map.getNumDims() - 1 << " : "
760  << dimSlices[map.getNumDims() - 1];
761  }
762  } else {
763  for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
764  printer << 'd' << i << ", ";
765  if (map.getNumDims() >= 1)
766  printer << 'd' << map.getNumDims() - 1;
767  }
768 }
769 
770 void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
771  ArrayRef<LevelType> lvlTypes) const {
772  for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
773  map.getResult(i).print(printer.getStream());
774  printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
775  }
776  if (map.getNumResults() >= 1) {
777  auto lastIndex = map.getNumResults() - 1;
778  map.getResult(lastIndex).print(printer.getStream());
779  printer << " : " << toMLIRString(lvlTypes[lastIndex]);
780  }
781 }
782 
783 LogicalResult SparseTensorEncodingAttr::verify(
784     function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
785     AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
786     unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
787     ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
788   if (!acceptBitWidth(posWidth))
789  return emitError() << "unexpected position bitwidth: " << posWidth;
790  if (!acceptBitWidth(crdWidth))
791  return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
792 
793  // Verify every COO segment.
794  auto *it = llvm::find_if(lvlTypes, isSingletonLT);
795  while (it != lvlTypes.end()) {
796  if (it == lvlTypes.begin() ||
797         !(it - 1)->isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>())
798       return emitError() << "expected compressed or loose_compressed level "
799  "before singleton level";
800 
801  auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT);
802  if (!std::all_of(it, curCOOEnd, isSingletonLT))
803  return emitError() << "expected all singleton lvlTypes "
804  "following a singleton level";
805  // We can potentially support mixed SoA/AoS singleton levels.
806  if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
807  return it->isa<LevelPropNonDefault::SoA>() ==
808                  i.isa<LevelPropNonDefault::SoA>();
809         })) {
810  return emitError() << "expected all singleton lvlTypes stored in the "
811  "same memory layout (SoA vs AoS).";
812  }
813  it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
814  }
815 
816  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
817  if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
818  return emitError() << "Batch lvlType can only be leading levels.";
819 
820  // SoA property can only be applied on singleton level.
821  auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
822  return lt.isa<LevelPropNonDefault::SoA>();
823  });
824  if (llvm::any_of(soaLvls, [](LevelType lt) {
825  return !lt.isa<LevelFormat::Singleton>();
826  })) {
827  return emitError() << "SoA is only applicable to singleton lvlTypes.";
828  }
829 
830  // TODO: audit formats that actually are supported by backend.
831  if (auto it = llvm::find_if(lvlTypes, isNOutOfMLT);
832  it != std::end(lvlTypes)) {
833  if (it != lvlTypes.end() - 1)
834  return emitError() << "expected n_out_of_m to be the last level type";
835  if (!std::all_of(lvlTypes.begin(), it, isDenseLT))
836  return emitError() << "expected all dense lvlTypes "
837  "before a n_out_of_m level";
838  if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
839  if (!isBlockSparsity(dimToLvl)) {
840  return emitError()
841  << "expected 1xm block structure for n_out_of_m level";
842  }
843  auto sizes = getBlockSize(dimToLvl);
844  unsigned coefficient = 0;
845  for (const auto &elem : sizes) {
846  if (elem != 0) {
847  if (elem != coefficient && coefficient != 0) {
848  return emitError() << "expected only one blocked level "
849  "with the same coefficients";
850  }
851  coefficient = elem;
852  }
853  }
854  if (coefficient != getM(*it)) {
855         return emitError() << "expected coefficients of Affine expressions "
856  "to be equal to m of n_out_of_m level";
857  }
858  }
859  }
860  // Before we can check that the level-rank is consistent/coherent
861  // across all fields, we need to define it. The source-of-truth for
862  // the `getLvlRank` method is the length of the level-types array,
863  // since it must always be provided and have full rank; therefore we
864  // use that same source-of-truth here.
865  const Level lvlRank = lvlTypes.size();
866  if (lvlRank == 0)
867  return emitError() << "expected a non-empty array for lvlTypes";
868  // We save `dimRank` here because we'll also need it to verify `dimSlices`.
869  const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
870  if (dimToLvl) {
871  if (dimToLvl.getNumResults() != lvlRank)
872  return emitError()
873  << "level-rank mismatch between dimToLvl and lvlTypes: "
874  << dimToLvl.getNumResults() << " != " << lvlRank;
875  auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
876  // Symbols can't be inferred but are acceptable.
877  if (!inferRes && dimToLvl.getNumSymbols() == 0)
878  return emitError() << "failed to infer lvlToDim from dimToLvl";
879  if (lvlToDim && (inferRes != lvlToDim))
880  return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
881  if (dimRank > lvlRank)
882  return emitError() << "unexpected dimToLvl mapping from " << dimRank
883  << " to " << lvlRank;
884  }
885  if (!dimSlices.empty()) {
886  if (dimSlices.size() != dimRank)
887  return emitError()
888  << "dimension-rank mismatch between dimSlices and dimToLvl: "
889  << dimSlices.size() << " != " << dimRank;
890  // Compiler support for `dimSlices` currently requires that the two
891  // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
892  if (dimRank != lvlRank)
893  return emitError()
894  << "dimSlices expected dimension-rank to match level-rank: "
895  << dimRank << " != " << lvlRank;
896  }
897  return success();
898 }
899 
900 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
901  ArrayRef<Size> dimShape, Type elementType,
902     function_ref<InFlightDiagnostic()> emitError) const {
903   // Check structural integrity. In particular, this ensures that the
904  // level-rank is coherent across all the fields.
905  if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
906  getPosWidth(), getCrdWidth(), getExplicitVal(),
907  getImplicitVal(), getDimSlices())))
908  return failure();
909  // Check integrity with tensor type specifics. In particular, we
910  // need only check that the dimension-rank of the tensor agrees with
911  // the dimension-rank of the encoding.
912  const Dimension dimRank = dimShape.size();
913  if (dimRank == 0)
914  return emitError() << "expected non-scalar sparse tensor";
915  if (getDimRank() != dimRank)
916  return emitError()
917  << "dimension-rank mismatch between encoding and tensor shape: "
918  << getDimRank() << " != " << dimRank;
919  if (auto expVal = getExplicitVal()) {
920  Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
921  if (attrType != elementType) {
922  return emitError() << "explicit value type mismatch between encoding and "
923  << "tensor element type: " << attrType
924  << " != " << elementType;
925  }
926  }
927  if (auto impVal = getImplicitVal()) {
928  Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
929  if (attrType != elementType) {
930  return emitError() << "implicit value type mismatch between encoding and "
931  << "tensor element type: " << attrType
932  << " != " << elementType;
933  }
934  // Currently, we only support zero as the implicit value.
935  auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
936  auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
937  auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
938  if ((impFVal && impFVal.getValue().isNonZero()) ||
939  (impIntVal && !impIntVal.getValue().isZero()) ||
940  (impComplexVal && (impComplexVal.getImag().isNonZero() ||
941  impComplexVal.getReal().isNonZero()))) {
942  return emitError() << "implicit value must be zero";
943  }
944  }
945  return success();
946 }
947 
948 Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
949  SmallVector<COOSegment> coo = getCOOSegments();
950  assert(coo.size() == 1 || coo.empty());
951  if (!coo.empty() && coo.front().isAoS()) {
952  return coo.front().lvlRange.first;
953  }
954  return getLvlRank();
955 }
956 
957 SmallVector<COOSegment>
958 mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
959   SmallVector<COOSegment> ret;
960   if (getLvlRank() <= 1)
961  return ret;
962 
963  ArrayRef<LevelType> lts = getLvlTypes();
964  Level l = 0;
965  while (l < getLvlRank()) {
966  auto lt = lts[l];
967     if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
968       auto cur = lts.begin() + l;
969  auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
970  return !lt.isa<LevelFormat::Singleton>();
971  });
972  unsigned cooLen = std::distance(cur, end);
973  if (cooLen > 1) {
974       // To support mixed SoA/AoS COO, we should break the segment when the
975       // storage scheme changes; for now we faithfully assume that all
976       // consecutive singleton levels have the same storage format, as
977       // verified by the STEA verifier.
978  ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
979  lts[l + 1].isa<LevelPropNonDefault::SoA>()});
980  }
981  l += cooLen;
982  } else {
983  l++;
984  }
985  }
986  return ret;
987 }
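// Illustrative sketch (added commentary, not from the original source): for
// lvlTypes [dense, compressed(nonunique), singleton], getCOOSegments() above
// returns a single AoS segment covering levels [1, 3), whereas
// [dense, compressed(nonunique), singleton(soa)] yields the same range marked
// as SoA.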
988 
989 //===----------------------------------------------------------------------===//
990 // SparseTensorType Methods.
991 //===----------------------------------------------------------------------===//
992 
993 bool SparseTensorType::isCOOType(Level startLvl,
994                                  bool isUnique) const {
995  if (!hasEncoding())
996  return false;
997  if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
998  return false;
999  for (Level l = startLvl + 1; l < lvlRank; ++l)
1000  if (!isSingletonLvl(l))
1001  return false;
1002  // If isUnique is true, then make sure that the last level is unique,
1003  // that is, when lvlRank == 1, the only compressed level is unique,
1004  // and when lvlRank > 1, the last singleton is unique.
1005  return !isUnique || isUniqueLvl(lvlRank - 1);
1006 }
1007 
1008 RankedTensorType
1009 SparseTensorType::getCOOType(bool ordered) const {
1010   SmallVector<LevelType> lvlTypes;
1011  lvlTypes.reserve(lvlRank);
1012  // A non-unique compressed level at beginning (unless this is
1013  // also the last level, then it is unique).
1014  lvlTypes.push_back(
1015  *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
1016  if (lvlRank > 1) {
1017  // Followed by n-2 non-unique singleton levels.
1018  std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
1019  *buildLevelType(LevelFormat::Singleton, ordered, false));
1020  // Ends by a unique singleton level.
1021  lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
1022  }
1023  auto enc = SparseTensorEncodingAttr::get(
1024  getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
1025  getCrdWidth(), getExplicitVal(), getImplicitVal());
1026  return RankedTensorType::get(getDimShape(), getElementType(), enc);
1027 }
1028 
1029 //===----------------------------------------------------------------------===//
1030 // Convenience Methods.
1031 //===----------------------------------------------------------------------===//
1032 
1033 SparseTensorEncodingAttr
1034 mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
1035   if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
1036  return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
1037  if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
1038  return mdtp.getEncoding();
1039  return nullptr;
1040 }
1041 
1042 AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
1043                                              MLIRContext *context) {
1044  auto map = static_cast<AffineMap>(dimToLvl);
1045  AffineMap lvlToDim;
1046  // Return an empty lvlToDim when inference is not successful.
1047  if (!map || map.getNumSymbols() != 0) {
1048  lvlToDim = AffineMap();
1049  } else if (map.isPermutation()) {
1050  lvlToDim = inversePermutation(map);
1051  } else if (isBlockSparsity(map)) {
1052  lvlToDim = inverseBlockSparsity(map, context);
1053  }
1054  return lvlToDim;
1055 }
1056 
1057 AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
1058                                                     MLIRContext *context) {
1059  SmallVector<AffineExpr> lvlExprs;
1060  auto numLvls = dimToLvl.getNumResults();
1061  lvlExprs.reserve(numLvls);
1062  // lvlExprComponents stores information of the floordiv and mod operations
1063  // applied to the same dimension, so as to build the lvlToDim map.
1064  std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
1065  for (unsigned i = 0, n = numLvls; i < n; i++) {
1066  auto result = dimToLvl.getResult(i);
1067  if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1068  if (result.getKind() == AffineExprKind::FloorDiv) {
1069  // Position of the dimension in dimToLvl.
1070  auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1071  assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
1072  "expected only one floordiv for each dimension");
1073  SmallVector<AffineExpr, 3> components;
1074  // Level variable for floordiv.
1075  components.push_back(getAffineDimExpr(i, context));
1076  // Multiplier.
1077  components.push_back(binOp.getRHS());
1078  // Map key is the position of the dimension.
1079  lvlExprComponents[pos] = components;
1080  } else if (result.getKind() == AffineExprKind::Mod) {
1081  auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1082  assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
1083  "expected floordiv before mod");
1084  // Add level variable for mod to the same vector
1085  // of the corresponding floordiv.
1086  lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
1087  } else {
1088  assert(false && "expected floordiv or mod");
1089  }
1090  } else {
1091  lvlExprs.push_back(getAffineDimExpr(i, context));
1092  }
1093  }
1094  // Build lvlExprs from lvlExprComponents.
1095  // For example, for il = i floordiv 2 and ii = i mod 2, the components
1096  // would be [il, 2, ii]. It could be used to build the AffineExpr
1097  // i = il * 2 + ii in lvlToDim.
1098  for (auto &components : lvlExprComponents) {
1099  assert(components.second.size() == 3 &&
1100  "expected 3 components to build lvlExprs");
1101  auto mulOp = getAffineBinaryOpExpr(
1102  AffineExprKind::Mul, components.second[0], components.second[1]);
1103  auto addOp =
1104  getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
1105  lvlExprs.push_back(addOp);
1106  }
1107  return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
1108 }
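// Illustrative sketch (added commentary, not from the original source): for
// the BSR-style map `(d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2,
// d1 mod 3)`, the routine above pairs each floordiv with its mod and produces
// the inverse `(d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 3 + d3)`.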
1109 
1110 SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
1111   assert(isBlockSparsity(dimToLvl) &&
1112  "expected dimToLvl to be block sparsity for calling getBlockSize");
1113  SmallVector<unsigned> blockSize;
1114  for (auto result : dimToLvl.getResults()) {
1115  if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1116  if (result.getKind() == AffineExprKind::Mod) {
1117  blockSize.push_back(
1118  dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
1119  }
1120  } else {
1121  blockSize.push_back(0);
1122  }
1123  }
1124  return blockSize;
1125 }
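// Illustrative sketch (added commentary, not from the original source): for
// the same BSR-style map, getBlockSize() returns [2, 3]; a level that maps a
// dimension directly (no floordiv/mod) contributes a 0 entry, e.g.
// `(d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)` yields [0, 4].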
1126 
1127 bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
1128   if (!dimToLvl)
1129  return false;
1130  std::map<unsigned, int64_t> coeffientMap;
1131  bool hasBlock = false;
1132  for (auto result : dimToLvl.getResults()) {
1133  if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1134  // Check for "dim op const".
1135  auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
1136  auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
1137  if (!dimOp || !conOp || conOp.getValue() <= 0)
1138  return false;
1139  // Inspect "dim / const" or "dim % const".
1140  auto pos = dimOp.getPosition();
1141  if (binOp.getKind() == AffineExprKind::FloorDiv) {
1142  // Expect only one floordiv for each dimension.
1143  auto [it, inserted] = coeffientMap.try_emplace(pos);
1144  if (!inserted)
1145  return false;
1146  // Record coefficient of the floordiv.
1147  it->second = conOp.getValue();
1148  } else if (binOp.getKind() == AffineExprKind::Mod) {
1149  // Expect floordiv before mod.
1150  auto it = coeffientMap.find(pos);
1151  if (it == coeffientMap.end())
1152  return false;
1153  // Expect mod to have the same coefficient as floordiv.
1154  if (conOp.getValue() != it->second)
1155  return false;
1156  hasBlock = true;
1157  } else {
1158  return false;
1159  }
1160  } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
1161  auto pos = dimOp.getPosition();
1162  // Expect dim to be unset.
1163  if (!coeffientMap.try_emplace(pos, 0).second)
1164  return false;
1165  } else {
1166  return false;
1167  }
1168  }
1169  return hasBlock;
1170 }
1171 
1172 bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
1173   auto hasNonIdentityMap = [](Value v) {
1174  auto stt = tryGetSparseTensorType(v);
1175  return stt && !stt->isIdentity();
1176  };
1177 
1178  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
1179  llvm::any_of(op->getResults(), hasNonIdentityMap);
1180 }
1181 
1182 Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
1183  if (enc) {
1184  assert(enc.isPermutation() && "Non permutation map not supported");
1185  if (const auto dimToLvl = enc.getDimToLvl())
1186  return dimToLvl.getDimPosition(l);
1187  }
1188  return l;
1189 }
1190 
1191 Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
1192  if (enc) {
1193  assert(enc.isPermutation() && "Non permutation map not supported");
1194  if (const auto lvlToDim = enc.getLvlToDim())
1195  return lvlToDim.getDimPosition(d);
1196  }
1197  return d;
1198 }
1199 
1200 /// We normalize the sparse tensor encoding attribute by always using
1201 /// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
1202 /// as other variants) lead to the same storage specifier type, and by stripping
1203 /// irrelevant fields that do not alter the sparse tensor memory layout.
1204 static SparseTensorEncodingAttr
1205 getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
1206   SmallVector<LevelType> lts;
1207   for (auto lt : enc.getLvlTypes())
1208  lts.push_back(lt.stripStorageIrrelevantProperties());
1209 
1210   return SparseTensorEncodingAttr::get(
1211       enc.getContext(), lts,
1212  AffineMap(), // dimToLvl (irrelevant to storage specifier)
1213  AffineMap(), // lvlToDim (irrelevant to storage specifier)
1214       // Always use `index` for memSize and lvlSize instead of reusing
1215       // `getPosWidth` and `getCrdWidth`. This allows us to reuse the same SSA
1216       // value for different bitwidths; it also avoids casting between index
1217       // and integer (returned by DimOp).
1218  0, 0,
1219  Attribute(), // explicitVal (irrelevant to storage specifier)
1220  Attribute(), // implicitVal (irrelevant to storage specifier)
1221  enc.getDimSlices());
1222 }
1223 
1224 StorageSpecifierType
1225 StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
1226  return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
1227 }
1228 
1229 StorageSpecifierType
1230 StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError,
1231  MLIRContext *ctx,
1232  SparseTensorEncodingAttr encoding) {
1233  return Base::getChecked(emitError, ctx,
1234                           getNormalizedEncodingForSpecifier(encoding));
1235 }
1236 
1237 //===----------------------------------------------------------------------===//
1238 // SparseTensorDialect Operations.
1239 //===----------------------------------------------------------------------===//
1240 
1241 static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
1242  return success(lvl < getSparseTensorType(tensor).getLvlRank());
1243 }
1244 
1245 static LogicalResult isMatchingWidth(Value mem, unsigned width) {
1246  const Type etp = getMemRefType(mem).getElementType();
1247  return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
1248 }
1249 
1250 static LogicalResult verifySparsifierGetterSetter(
1251  StorageSpecifierKind mdKind, std::optional<Level> lvl,
1252     TypedValue<StorageSpecifierType> md, Operation *op) {
1253   if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
1254  return op->emitError(
1255  "redundant level argument for querying value memory size");
1256  }
1257 
1258  const auto enc = md.getType().getEncoding();
1259  const Level lvlRank = enc.getLvlRank();
1260 
1261  if (mdKind == StorageSpecifierKind::DimOffset ||
1262  mdKind == StorageSpecifierKind::DimStride)
1263  if (!enc.isSlice())
1264  return op->emitError("requested slice data on non-slice tensor");
1265 
1266  if (mdKind != StorageSpecifierKind::ValMemSize) {
1267  if (!lvl)
1268  return op->emitError("missing level argument");
1269 
1270  const Level l = lvl.value();
1271  if (l >= lvlRank)
1272  return op->emitError("requested level is out of bounds");
1273 
1274  if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
1275  return op->emitError(
1276  "requested position memory size on a singleton level");
1277  }
1278  return success();
1279 }
1280 
1281 static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
1282   switch (kind) {
1283   case SparseTensorFieldKind::CrdMemRef:
1284     return stt.getCrdType();
1285   case SparseTensorFieldKind::PosMemRef:
1286     return stt.getPosType();
1287   case SparseTensorFieldKind::ValMemRef:
1288     return stt.getElementType();
1289   case SparseTensorFieldKind::StorageSpec:
1290     return nullptr;
1291   }
1292  llvm_unreachable("Unrecognizable FieldKind");
1293 }
1294 
1295 static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
1296  SparseTensorType stt,
1297  RankedTensorType valTp,
1298  TypeRange lvlTps) {
1299  if (requiresStaticShape && !stt.hasStaticDimShape())
1300  return op->emitError("the sparse-tensor must have static shape");
1301  if (!stt.hasEncoding())
1302  return op->emitError("the sparse-tensor must have an encoding attribute");
1303 
1304  // Verifies the trailing COO.
1305  Level cooStartLvl = stt.getAoSCOOStart();
1306  if (cooStartLvl < stt.getLvlRank()) {
1307     // We only support trailing COO for now; it must be the last input.
1308     auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
1309     // The coordinates should be in the shape of <? x rank>.
1310  unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
1311  if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
1312  return op->emitError("input/output trailing COO level-ranks don't match");
1313  }
1314  }
1315 
1316  // Verifies that all types match.
1317  StorageLayout layout(stt.getEncoding());
1318  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
1319  return op->emitError("inconsistent number of fields between input/output");
1320 
1321  unsigned idx = 0;
1322  bool misMatch = false;
1323  layout.foreachField([&idx, &misMatch, stt, valTp,
1324  lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
1325  Level lvl, LevelType lt) -> bool {
1326     if (fKind == SparseTensorFieldKind::StorageSpec)
1327       return true;
1328 
1329  Type inputTp = nullptr;
1330  if (fKind == SparseTensorFieldKind::ValMemRef) {
1331  inputTp = valTp;
1332  } else {
1333  assert(fid == idx && stt.getLvlType(lvl) == lt);
1334  inputTp = lvlTps[idx++];
1335  }
1336  // The input element type and expected element type should match.
1337  Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
1338  Type expElemTp = getFieldElemType(stt, fKind);
1339  if (inpElemTp != expElemTp) {
1340  misMatch = true;
1341  return false; // to terminate the iteration
1342  }
1343  return true;
1344  });
1345 
1346  if (misMatch)
1347  return op->emitError("input/output element-types don't match");
1348  return success();
1349 }
1350 
1351 LogicalResult AssembleOp::verify() {
1352  RankedTensorType valuesTp = getValues().getType();
1353  const auto lvlsTp = getLevels().getTypes();
1354  const auto resTp = getSparseTensorType(getResult());
1355  return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
1356 }
1357 
1358 LogicalResult DisassembleOp::verify() {
1359  if (getOutValues().getType() != getRetValues().getType())
1360  return emitError("output values and return value type mismatch");
1361 
1362  for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1363  if (ot.getType() != rt.getType())
1364  return emitError("output levels and return levels type mismatch");
1365 
1366  RankedTensorType valuesTp = getRetValues().getType();
1367  const auto lvlsTp = getRetLevels().getTypes();
1368  const auto srcTp = getSparseTensorType(getTensor());
1369  return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
1370 }
1371 
1372 LogicalResult ConvertOp::verify() {
1373  RankedTensorType tp1 = getSource().getType();
1374  RankedTensorType tp2 = getDest().getType();
1375  if (tp1.getRank() != tp2.getRank())
1376  return emitError("unexpected conversion mismatch in rank");
1377  auto dstEnc =
1378  llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
1379  if (dstEnc && dstEnc.isSlice())
1380  return emitError("cannot convert to a sparse tensor slice");
1381 
1382  auto shape1 = tp1.getShape();
1383  auto shape2 = tp2.getShape();
1384  // Accept size matches between the source and the destination type
1385  // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
1386  // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
1387  for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1388  if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1389  return emitError("unexpected conversion mismatch in dimension ") << d;
1390  return success();
1391 }
1392 
1393 OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
1394  if (getType() == getSource().getType())
1395  return getSource();
1396  return {};
1397 }
1398 
1399 bool ConvertOp::needsExtraSort() {
1400  SparseTensorType srcStt = getSparseTensorType(getSource());
1401  SparseTensorType dstStt = getSparseTensorType(getDest());
1402 
1403   // We do not need an extra sort when returning unordered sparse tensors or
1404   // dense tensors, since dense tensors support random access.
1405  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1406  return false;
1407 
1408  if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
1409  srcStt.hasSameDimToLvl(dstStt)) {
1410  return false;
1411  }
1412 
1413   // Source and dest tensors are ordered in different ways. We only do direct
1414   // dense to sparse conversion when the dense input is defined by a sparse
1415   // constant. Note that we can theoretically always directly convert from
1416   // dense inputs by rotating dense loops, but it leads to bad cache locality
1417   // and hurts performance.
1418  if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1419  if (isa<SparseElementsAttr>(constOp.getValue()))
1420  return false;
1421 
1422  return true;
1423 }
1424 
1425 LogicalResult CrdTranslateOp::verify() {
1426  uint64_t inRank = getEncoder().getLvlRank();
1427  uint64_t outRank = getEncoder().getDimRank();
1428 
1429  if (getDirection() == CrdTransDirectionKind::dim2lvl)
1430  std::swap(inRank, outRank);
1431 
1432  if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1433  return emitError("Coordinate rank mismatch with encoding");
1434 
1435  return success();
1436 }
1437 
1438 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
1439  SmallVectorImpl<OpFoldResult> &results) {
1440  if (getEncoder().isIdentity()) {
1441  results.assign(getInCrds().begin(), getInCrds().end());
1442  return success();
1443  }
1444  if (getEncoder().isPermutation()) {
1445  AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
1446  ? getEncoder().getDimToLvl()
1447  : getEncoder().getLvlToDim();
1448  for (AffineExpr exp : perm.getResults())
1449  results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
1450  return success();
1451  }
1452 
1453  // Fuse dim2lvl/lvl2dim pairs.
1454  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
1455  bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
1456  return v.getDefiningOp() == def;
1457  });
1458  if (!sameDef)
1459  return failure();
1460 
1461  bool oppositeDir = def.getDirection() != getDirection();
1462  bool sameOracle =
1463  def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
1464  bool sameCount = def.getNumResults() == getInCrds().size();
1465  if (!oppositeDir || !sameOracle || !sameCount)
1466  return failure();
1467 
1468  // The definition produces the coordinates in the same order as the input
1469  // coordinates.
1470  bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
1471  [](auto valuePair) {
1472  auto [lhs, rhs] = valuePair;
1473  return lhs == rhs;
1474  });
1475 
1476  if (!sameOrder)
1477  return failure();
1478  // l1 = dim2lvl (lvl2dim l0)
1479  // ==> l0
1480  results.append(def.getInCrds().begin(), def.getInCrds().end());
1481  return success();
1482 }
1483 
1484 void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1485  int64_t index) {
1486  Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
1487  return build(builder, state, source, val);
1488 }
1489 
1490 LogicalResult LvlOp::verify() {
1491  if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1492  auto stt = getSparseTensorType(getSource());
1493  if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
1494  return emitError(
1495  "Level index exceeds the rank of the input sparse tensor");
1496  }
1497  return success();
1498 }
1499 
1500 std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1501  return getConstantIntValue(getIndex());
1502 }
1503 
1504 Speculation::Speculatability LvlOp::getSpeculatability() {
1505  auto constantIndex = getConstantLvlIndex();
1506  if (!constantIndex)
1507     return Speculation::NotSpeculatable;
1508 
1509   assert(constantIndex <
1510          cast<RankedTensorType>(getSource().getType()).getRank());
1511   return Speculation::Speculatable;
1512 }
1513 
1514 OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1515  auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1516  if (!lvlIndex)
1517  return {};
1518 
1519  Level lvl = lvlIndex.getAPSInt().getZExtValue();
1520  auto stt = getSparseTensorType(getSource());
1521  if (lvl >= stt.getLvlRank()) {
1522  // Follows the same convention used by tensor.dim operation. Out of bound
1523  // indices produce undefined behavior but are still valid IR. Don't choke on
1524  // them.
1525  return {};
1526  }
1527 
1528  // Helper lambda to build an IndexAttr.
1529  auto getIndexAttr = [this](int64_t lvlSz) {
1530  return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
1531  };
1532 
1533  SmallVector<Size> lvlShape = stt.getLvlShape();
1534  if (!ShapedType::isDynamic(lvlShape[lvl]))
1535  return getIndexAttr(lvlShape[lvl]);
1536 
1537  return {};
1538 }
1539 
1540 void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1541  SparseTensorEncodingAttr dstEnc, Value source) {
1542  auto srcStt = getSparseTensorType(source);
1543  SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
1544  SmallVector<int64_t> dstDimShape =
1545  dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1546  auto dstTp =
1547  RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
1548  return build(odsBuilder, odsState, dstTp, source);
1549 }
1550 
1551 LogicalResult ReinterpretMapOp::verify() {
1552  auto srcStt = getSparseTensorType(getSource());
1553  auto dstStt = getSparseTensorType(getDest());
1554  ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
1555  ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();
1556 
1557  if (srcLvlTps.size() != dstLvlTps.size())
1558  return emitError("Level rank mismatch between source/dest tensors");
1559 
1560  for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1561  if (srcLvlTp != dstLvlTp)
1562  return emitError("Level type mismatch between source/dest tensors");
1563 
1564  if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
1565  srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
1566  return emitError("Crd/Pos width mismatch between source/dest tensors");
1567  }
1568 
1569  if (srcStt.getElementType() != dstStt.getElementType())
1570  return emitError("Element type mismatch between source/dest tensors");
1571 
1572  SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
1573  SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
1574  for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1575  if (srcLvlSz != dstLvlSz) {
1576  // Should we allow one side to be dynamic size, e.g., <?x?> should be
1577  // compatible to <3x4>? For now, we require all the level sizes to be
1578  // *exactly* matched for simplicity.
1579  return emitError("Level size mismatch between source/dest tensors");
1580  }
1581  }
1582 
1583  return success();
1584 }
1585 
1586 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1587  if (getSource().getType() == getDest().getType())
1588  return getSource();
1589 
1590  if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1591  // A -> B, B -> A ==> A
1592  if (def.getSource().getType() == getDest().getType())
1593  return def.getSource();
1594  }
1595  return {};
1596 }
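// Example (illustrative, syntax abridged): a reinterpret_map to the source's
// own type folds to the source, and a round trip
//   %b = sparse_tensor.reinterpret_map %a : tensor<..., #A> to tensor<..., #B>
//   %c = sparse_tensor.reinterpret_map %b : tensor<..., #B> to tensor<..., #A>
// folds %c directly to %a (the "A -> B, B -> A ==> A" case above).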
1597 
1598 template <typename ToBufferOp>
1599 static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
1600  OpaqueProperties prop,
1601  RegionRange region,
1602  SmallVectorImpl<mlir::Type> &ret) {
1603  typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1604  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
1605  Type elemTp = nullptr;
1606  bool withStride = false;
1607  if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1608  elemTp = stt.getPosType();
1609  } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1610  std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1611  elemTp = stt.getCrdType();
1612  if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1613  withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
1614  } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1615  elemTp = stt.getElementType();
1616  }
1617 
1618  assert(elemTp && "unhandled operation.");
1619  SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
1620  bufShape.push_back(ShapedType::kDynamic);
1621 
1622  auto layout = withStride ? StridedLayoutAttr::get(
1623  stt.getContext(), ShapedType::kDynamic,
1624  {ShapedType::kDynamic})
1625  : StridedLayoutAttr();
1626  ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1627  return success();
1628 }
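// Example (illustrative): for a non-batched CSR tensor with posWidth = 32 and
// crdWidth = 64, the helper above infers memref<?xi32> for ToPositionsOp and
// memref<?xi64> for ToCoordinatesOp; coordinates read from an AoS COO region
// additionally carry a dynamic strided layout (unknown offset and stride).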
1629 
1630 LogicalResult ToPositionsOp::verify() {
1631  auto stt = getSparseTensorType(getTensor());
1632  if (failed(lvlIsInBounds(getLevel(), getTensor())))
1633  return emitError("requested level is out of bounds");
1634  if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
1635  return emitError("unexpected type for positions");
1636  return success();
1637 }
1638 
1639 LogicalResult
1640 ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1641  ValueRange ops, DictionaryAttr attr,
1642  OpaqueProperties prop, RegionRange region,
1643  SmallVectorImpl<mlir::Type> &ret) {
1644  return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
1645 }
1646 
1647 LogicalResult ToCoordinatesOp::verify() {
1648  auto stt = getSparseTensorType(getTensor());
1649  if (failed(lvlIsInBounds(getLevel(), getTensor())))
1650  return emitError("requested level is out of bounds");
1651  if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
1652  return emitError("unexpected type for coordinates");
1653  return success();
1654 }
1655 
1656 LogicalResult
1657 ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1658  ValueRange ops, DictionaryAttr attr,
1659  OpaqueProperties prop, RegionRange region,
1660  SmallVectorImpl<mlir::Type> &ret) {
1661  return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
1662 }
1663 
1664 LogicalResult ToCoordinatesBufferOp::verify() {
1665  auto stt = getSparseTensorType(getTensor());
1666  if (stt.getAoSCOOStart() >= stt.getLvlRank())
1667  return emitError("expected sparse tensor with a COO region");
1668  return success();
1669 }
1670 
1671 LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1672  MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
1673  DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
1674  SmallVectorImpl<mlir::Type> &ret) {
1675  return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
1676  ret);
1677 }
1678 
1679 LogicalResult ToValuesOp::verify() {
1680  auto stt = getSparseTensorType(getTensor());
1681  auto mtp = getMemRefType(getResult());
1682  if (stt.getElementType() != mtp.getElementType())
1683  return emitError("unexpected mismatch in element types");
1684  return success();
1685 }
1686 
1687 LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1688  std::optional<Location> loc,
1689  ValueRange ops, DictionaryAttr attr,
1690  OpaqueProperties prop,
1691  RegionRange region,
1692  SmallVectorImpl<mlir::Type> &ret) {
1693  return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
1694 }
1695 
1696 LogicalResult ToSliceOffsetOp::verify() {
1697  auto rank = getSlice().getType().getRank();
1698  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1699  return emitError("requested dimension out of bound");
1700  return success();
1701 }
1702 
1703 LogicalResult ToSliceStrideOp::verify() {
1704  auto rank = getSlice().getType().getRank();
1705  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1706  return emitError("requested dimension out of bound");
1707  return success();
1708 }
1709 
1710 LogicalResult GetStorageSpecifierOp::verify() {
1711  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1712  getSpecifier(), getOperation());
1713 }
1714 
1715 template <typename SpecifierOp>
1716 static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
1717  return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1718 }
1719 
1720 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1721  const StorageSpecifierKind kind = getSpecifierKind();
1722  const auto lvl = getLevel();
1723  for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1724  if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1725  return op.getValue();
1726  return {};
1727 }
1728 
1729 LogicalResult SetStorageSpecifierOp::verify() {
1730  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1731  getSpecifier(), getOperation());
1732 }
1733 
1734 template <class T>
1735 static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1736  const char *regionName,
1737  TypeRange inputTypes, Type outputType) {
1738  unsigned numArgs = region.getNumArguments();
1739  unsigned expectedNum = inputTypes.size();
1740  if (numArgs != expectedNum)
1741  return op->emitError() << regionName << " region must have exactly "
1742  << expectedNum << " arguments";
1743 
1744  for (unsigned i = 0; i < numArgs; i++) {
1745  Type typ = region.getArgument(i).getType();
1746  if (typ != inputTypes[i])
1747  return op->emitError() << regionName << " region argument " << (i + 1)
1748  << " type mismatch";
1749  }
1750  Operation *term = region.front().getTerminator();
1751  YieldOp yield = dyn_cast<YieldOp>(term);
1752  if (!yield)
1753  return op->emitError() << regionName
1754  << " region must end with sparse_tensor.yield";
1755  if (!yield.hasSingleResult() ||
1756  yield.getSingleResult().getType() != outputType)
1757  return op->emitError() << regionName << " region yield type mismatch";
1758 
1759  return success();
1760 }
1761 
1762 LogicalResult BinaryOp::verify() {
1763  NamedAttrList attrs = (*this)->getAttrs();
1764  Type leftType = getX().getType();
1765  Type rightType = getY().getType();
1766  Type outputType = getOutput().getType();
1767  Region &overlap = getOverlapRegion();
1768  Region &left = getLeftRegion();
1769  Region &right = getRightRegion();
1770 
1771  // Check correct number of block arguments and return type for each
1772  // non-empty region.
1773  if (!overlap.empty()) {
1774  if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1775  TypeRange{leftType, rightType}, outputType)))
1776  return failure();
1777  }
1778  if (!left.empty()) {
1779  if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1780  outputType)))
1781  return failure();
1782  } else if (getLeftIdentity()) {
1783  if (leftType != outputType)
1784  return emitError("left=identity requires first argument to have the same "
1785  "type as the output");
1786  }
1787  if (!right.empty()) {
1788  if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1789  outputType)))
1790  return failure();
1791  } else if (getRightIdentity()) {
1792  if (rightType != outputType)
1793  return emitError("right=identity requires second argument to have the "
1794  "same type as the output");
1795  }
1796  return success();
1797 }
1798 
1799 LogicalResult UnaryOp::verify() {
1800  Type inputType = getX().getType();
1801  Type outputType = getOutput().getType();
1802 
1803  // Check correct number of block arguments and return type for each
1804  // non-empty region.
1805  Region &present = getPresentRegion();
1806  if (!present.empty()) {
1807  if (failed(verifyNumBlockArgs(this, present, "present",
1808  TypeRange{inputType}, outputType)))
1809  return failure();
1810  }
1811  Region &absent = getAbsentRegion();
1812  if (!absent.empty()) {
1813  if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
1814  outputType)))
1815  return failure();
1816  // Absent branch can only yield invariant values.
1817  Block *absentBlock = &absent.front();
1818  Block *parent = getOperation()->getBlock();
1819  Value absentVal =
1820  cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
1821  if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
1822  if (arg.getOwner() == parent)
1823  return emitError("absent region cannot yield linalg argument");
1824  } else if (Operation *def = absentVal.getDefiningOp()) {
1825  if (!isa<arith::ConstantOp>(def) &&
1826  (def->getBlock() == absentBlock || def->getBlock() == parent))
1827  return emitError("absent region cannot yield locally computed value");
1828  }
1829  }
1830  return success();
1831 }
1832 
1833 bool ConcatenateOp::needsExtraSort() {
1834  SparseTensorType dstStt = getSparseTensorType(*this);
1835  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1836  return false;
1837 
1838  bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1839  return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1840  });
1841  // TODO: When conDim != 0, as long as conDim corresponds to the first level
1842  // in all input/output buffers, and all input/output buffers have the same
1843  // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenating
1844  // CSC matrices along columns).
1845  bool directLowerable =
1846  allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1847  return !directLowerable;
1848 }
1849 
1850 LogicalResult ConcatenateOp::verify() {
1851  const auto dstTp = getSparseTensorType(*this);
1852  const Dimension concatDim = getDimension();
1853  const Dimension dimRank = dstTp.getDimRank();
1854 
1855  if (getInputs().size() <= 1)
1856  return emitError("Need at least two tensors to concatenate.");
1857 
1858  if (concatDim >= dimRank)
1859  return emitError(llvm::formatv(
1860  "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1861  concatDim, dimRank));
1862 
1863  for (const auto &it : llvm::enumerate(getInputs())) {
1864  const auto i = it.index();
1865  const auto srcTp = getSparseTensorType(it.value());
1866  if (srcTp.hasDynamicDimShape())
1867  return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1868  const Dimension srcDimRank = srcTp.getDimRank();
1869  if (srcDimRank != dimRank)
1870  return emitError(
1871  llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1872  "from the output tensor (rank={2}).",
1873  i, srcDimRank, dimRank));
1874  }
1875 
1876  for (Dimension d = 0; d < dimRank; d++) {
1877  const Size dstSh = dstTp.getDimShape()[d];
1878  if (d == concatDim) {
1879  if (!ShapedType::isDynamic(dstSh)) {
1880  // If we reach here, then all inputs have static shapes. So we
1881  // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1882  // to avoid redundant assertions in the loop.
1883  Size sumSz = 0;
1884  for (const auto src : getInputs())
1885  sumSz += getSparseTensorType(src).getDimShape()[d];
1886  // If all dimensions are statically known, the sum of all the input
1887  // dimensions should be equal to the output dimension.
1888  if (sumSz != dstSh)
1889  return emitError(
1890  "The concatenation dimension of the output tensor should be the "
1891  "sum of all the concatenation dimensions of the input tensors.");
1892  }
1893  } else {
1894  Size prev = dstSh;
1895  for (const auto src : getInputs()) {
1896  const auto sh = getSparseTensorType(src).getDimShape()[d];
1897  if (!ShapedType::isDynamic(prev) && sh != prev)
1898  return emitError("All dimensions (expect for the concatenating one) "
1899  "should be equal.");
1900  prev = sh;
1901  }
1902  }
1903  }
1904 
1905  return success();
1906 }
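// Example (illustrative; shapes and encodings are hypothetical): concatenating
// tensor<3x4xf64, #X> and tensor<5x4xf64, #X> along dimension 0 must produce
// tensor<8x4xf64, #Y> (or a dynamic size along dimension 0); all other
// dimensions, here the size-4 one, have to agree across inputs and output.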
1907 
1908 void PushBackOp::build(OpBuilder &builder, OperationState &result,
1909  Value curSize, Value inBuffer, Value value) {
1910  build(builder, result, curSize, inBuffer, value, Value());
1911 }
1912 
1913 LogicalResult PushBackOp::verify() {
1914  if (Value n = getN()) {
1915  std::optional<int64_t> nValue = getConstantIntValue(n);
1916  if (nValue && nValue.value() < 1)
1917  return emitOpError("n must be not less than 1");
1918  }
1919  return success();
1920 }
1921 
1922 LogicalResult CompressOp::verify() {
1923  const auto stt = getSparseTensorType(getTensor());
1924  if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1925  return emitOpError("incorrect number of coordinates");
1926  return success();
1927 }
1928 
1929 void ForeachOp::build(
1930  OpBuilder &builder, OperationState &result, Value tensor,
1931  ValueRange initArgs, AffineMapAttr order,
1932  function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
1933  bodyBuilder) {
1934  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
1935  // Builds foreach body.
1936  if (!bodyBuilder)
1937  return;
1938  const auto stt = getSparseTensorType(tensor);
1939  const Dimension dimRank = stt.getDimRank();
1940 
1941  // Starts with `dimRank`-many coordinates.
1942  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
1943  // Followed by one value.
1944  blockArgTypes.push_back(stt.getElementType());
1945  // Followed by the reduction variables.
1946  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
1947 
1948  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());
1949 
1950  OpBuilder::InsertionGuard guard(builder);
1951  auto &region = *result.regions.front();
1952  Block *bodyBlock =
1953  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1954  bodyBuilder(builder, result.location,
1955  bodyBlock->getArguments().slice(0, dimRank),
1956  bodyBlock->getArguments()[dimRank],
1957  bodyBlock->getArguments().drop_front(dimRank + 1));
1958 }
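// Example (illustrative): for a 2-d tensor and a single init argument, the
// body block created above has arguments (%i, %j, %v, %acc), i.e. dimRank
// index coordinates, then one element value, then the reduction arguments.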
1959 
1960 LogicalResult ForeachOp::verify() {
1961  const auto t = getSparseTensorType(getTensor());
1962  const Dimension dimRank = t.getDimRank();
1963  const auto args = getBody()->getArguments();
1964 
1965  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1966  return emitError("Level traverse order does not match tensor's level rank");
1967 
1968  if (dimRank + 1 + getInitArgs().size() != args.size())
1969  return emitError("Unmatched number of arguments in the block");
1970 
1971  if (getNumResults() != getInitArgs().size())
1972  return emitError("Mismatch in number of init arguments and results");
1973 
1974  if (getResultTypes() != getInitArgs().getTypes())
1975  return emitError("Mismatch in types of init arguments and results");
1976 
1977  // Cannot mark this const, because the getters aren't.
1978  auto yield = cast<YieldOp>(getBody()->getTerminator());
1979  if (yield.getNumOperands() != getNumResults() ||
1980  yield.getOperands().getTypes() != getResultTypes())
1981  return emitError("Mismatch in types of yield values and results");
1982 
1983  const auto iTp = IndexType::get(getContext());
1984  for (Dimension d = 0; d < dimRank; d++)
1985  if (args[d].getType() != iTp)
1986  return emitError(
1987  llvm::formatv("Expecting Index type for argument at index {0}", d));
1988 
1989  const auto elemTp = t.getElementType();
1990  const auto valueTp = args[dimRank].getType();
1991  if (elemTp != valueTp)
1992  return emitError(
1993  llvm::formatv("Unmatched element type between input tensor and "
1994  "block argument, expected:{0}, got: {1}",
1995  elemTp, valueTp));
1996  return success();
1997 }
1998 
1999 OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
2000  if (getSparseTensorEncoding(getInputCoo().getType()) ==
2001  getSparseTensorEncoding(getResultCoo().getType()))
2002  return getInputCoo();
2003 
2004  return {};
2005 }
2006 
2007 LogicalResult ReorderCOOOp::verify() {
2008  SparseTensorType srcStt = getSparseTensorType(getInputCoo());
2009  SparseTensorType dstStt = getSparseTensorType(getResultCoo());
2010 
2011  if (!srcStt.isCOOType() || !dstStt.isCOOType())
2012  return emitError("Expected COO sparse tensors only");
2013 
2014  if (!srcStt.hasSameDimToLvl(dstStt))
2015  return emitError("Unmatched dim2lvl map between input and result COO");
2016 
2017  if (srcStt.getPosType() != dstStt.getPosType() ||
2018  srcStt.getCrdType() != dstStt.getCrdType() ||
2019  srcStt.getElementType() != dstStt.getElementType())
2020  return emitError("Unmatched storage format between input and result COO");
2021 
2022  return success();
2023 }
2024 
2025 LogicalResult ReduceOp::verify() {
2026  Type inputType = getX().getType();
2027  Region &formula = getRegion();
2028  return verifyNumBlockArgs(this, formula, "reduce",
2029  TypeRange{inputType, inputType}, inputType);
2030 }
2031 
2032 LogicalResult SelectOp::verify() {
2033  Builder b(getContext());
2034  Type inputType = getX().getType();
2035  Type boolType = b.getI1Type();
2036  Region &formula = getRegion();
2037  return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
2038  boolType);
2039 }
2040 
2041 LogicalResult SortOp::verify() {
2042  AffineMap xPerm = getPermMap();
2043  uint64_t nx = xPerm.getNumDims();
2044  if (nx < 1)
2045  return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
2046 
2047  if (!xPerm.isPermutation())
2048  return emitError(
2049  llvm::formatv("Expected a permutation map, got {0}", xPerm));
2050 
2051  // We can't check the size of the buffers when n or buffer dimensions aren't
2052  // compile-time constants.
2053  std::optional<int64_t> cn = getConstantIntValue(getN());
2054  if (!cn)
2055  return success();
2056 
2057  // Verify dimensions.
2058  const auto checkDim = [&](Value v, Size minSize,
2059  const char *message) -> LogicalResult {
2060  const Size sh = getMemRefType(v).getShape()[0];
2061  if (!ShapedType::isDynamic(sh) && sh < minSize)
2062  return emitError(
2063  llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
2064  return success();
2065  };
2066  uint64_t n = cn.value();
2067  uint64_t ny = 0;
2068  if (auto nyAttr = getNyAttr())
2069  ny = nyAttr.getInt();
2070  if (failed(checkDim(getXy(), n * (nx + ny),
2071  "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
2072  return failure();
2073  for (Value opnd : getYs())
2074  if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
2075  return failure();
2076 
2077  return success();
2078 }
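// Example (illustrative): with a 2-d permutation map (nx = 2), ny = 1, and a
// constant n = 10, the xy buffer must provide at least 10 * (2 + 1) = 30
// entries and every y buffer at least 10; dynamically sized buffers are
// accepted as-is since their sizes cannot be checked statically.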
2079 
2080 //===----------------------------------------------------------------------===//
2081 // Sparse Tensor Iteration Operations.
2082 //===----------------------------------------------------------------------===//
2083 
2084 IterSpaceType IteratorType::getIterSpaceType() const {
2085  return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
2086  getHiLvl());
2087 }
2088 
2089 IteratorType IterSpaceType::getIteratorType() const {
2090  return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
2091 }
2092 
2093 /// Parses a level range in the form "$lo `to` $hi"
2094 /// or simply "$lo" if $hi - $lo = 1
2095 static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
2096  Level &lvlHi) {
2097  if (parser.parseInteger(lvlLo))
2098  return failure();
2099 
2100  if (succeeded(parser.parseOptionalKeyword("to"))) {
2101  if (parser.parseInteger(lvlHi))
2102  return failure();
2103  } else {
2104  lvlHi = lvlLo + 1;
2105  }
2106 
2107  if (lvlHi <= lvlLo)
2108  return parser.emitError(parser.getNameLoc(),
2109  "expect larger level upper bound than lower bound");
2110 
2111  return success();
2112 }
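// Example (illustrative): "2" parses as the single-level range [2, 3), while
// "0 to 3" parses as [0, 3); "3 to 1" is rejected because the upper bound must
// be strictly larger than the lower bound.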
2113 
2114 /// Parses a level range in the form "$lo `to` $hi"
2115 /// or simply "$lo" if $hi - $lo = 1
2116 static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
2117  IntegerAttr &lvlHiAttr) {
2118  Level lvlLo, lvlHi;
2119  if (parseLevelRange(parser, lvlLo, lvlHi))
2120  return failure();
2121 
2122  lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
2123  lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
2124  return success();
2125 }
2126 
2127 /// Prints a level range in the form "$lo `to` $hi"
2128 /// or simply "$lo" if $hi - $lo = 1
2129 static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2130 
2131  if (lo + 1 == hi)
2132  p << lo;
2133  else
2134  p << lo << " to " << hi;
2135 }
2136 
2137 /// Prints a level range in the form "$lo `to` $hi"
2138 /// or simply "$lo" if $hi - $lo = 1
2139 static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
2140  IntegerAttr lvlHi) {
2141  unsigned lo = lvlLo.getValue().getZExtValue();
2142  unsigned hi = lvlHi.getValue().getZExtValue();
2143  printLevelRange(p, lo, hi);
2144 }
2145 
2146 /// Parses a list of optionally defined values in the form of
2147 /// "(%val0, _, %val1, ...)", where `_` is used to annotate that the
2148 /// corresponding value is not defined (e.g., to represent an undefined
2149 /// coordinate in the sparse iteration space).
2150 static ParseResult parseOptionalDefinedList(
2151  OpAsmParser &parser, OperationState &state, I64BitSet &definedSet,
2152  SmallVectorImpl<OpAsmParser::Argument> &definedArgs,
2153  unsigned maxCnt = std::numeric_limits<unsigned>::max(),
2154  OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) {
2155  unsigned cnt = 0;
2156  ParseResult crdList =
2157  parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult {
2158  if (parser.parseOptionalKeyword("_")) {
2159  if (parser.parseArgument(definedArgs.emplace_back()))
2160  return failure();
2161  definedSet.set(cnt);
2162  }
2163  cnt += 1;
2164  return success();
2165  });
2166 
2167  if (cnt > maxCnt)
2168  return parser.emitError(parser.getNameLoc(),
2169  "parsed more value than expected.");
2170 
2171  if (failed(crdList)) {
2172  return parser.emitError(
2173  parser.getNameLoc(),
2174  "expecting SSA value or \"_\" for level coordinates");
2175  }
2176  assert(definedArgs.size() == definedSet.count());
2177  return success();
2178 }
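// Example (illustrative, hypothetical SSA names): parsing "(%a, _, %b)" yields
// definedArgs = [%a, %b] and sets bits 0 and 2 in definedSet, leaving bit 1
// clear for the undefined ("_") position.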
2179 
2180 static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size,
2181  Block::BlockArgListType blocksArgs,
2182  I64BitSet definedSet) {
2183  if (definedSet.empty())
2184  return;
2185 
2186  for (unsigned i = 0; i < size; i++) {
2187  if (definedSet[i]) {
2188  p << blocksArgs.front();
2189  blocksArgs = blocksArgs.drop_front();
2190  } else {
2191  p << "_";
2192  }
2193  if (i != size - 1)
2194  p << ", ";
2195  }
2196  assert(blocksArgs.empty());
2197 }
2198 
2199 static ParseResult
2200 parseUsedCoordList(OpAsmParser &parser, OperationState &state,
2201  SmallVectorImpl<OpAsmParser::Argument> &coords) {
2202  // Parse "at(%crd0, _, ...)"
2203  I64BitSet crdUsedLvlSet;
2204  if (succeeded(parser.parseOptionalKeyword("at")) &&
2205  failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords)))
2206  return failure();
2207 
2208  // Always use IndexType for the coordinate.
2209  for (auto &coord : coords)
2210  coord.type = parser.getBuilder().getIndexType();
2211 
2212  // Set the CrdUsedLvl bitset.
2213  state.addAttribute("crdUsedLvls",
2214  parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
2215  return success();
2216 }
2217 
2218 static ParseResult
2219 parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
2220  SmallVectorImpl<OpAsmParser::Argument> &iterators,
2221  SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
2222  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
2223  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
2224 
2225  // Parse "%iters, ... in %spaces, ..."
2226  if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
2227  parser.parseOperandList(spaces))
2228  return failure();
2229 
2230  if (iterators.size() != spaces.size())
2231  return parser.emitError(
2232  parser.getNameLoc(),
2233  "mismatch in number of sparse iterators and sparse spaces");
2234 
2235  SmallVector<OpAsmParser::Argument> coords;
2236  if (failed(parseUsedCoordList(parser, state, coords)))
2237  return failure();
2238  size_t numCrds = coords.size();
2239 
2240  // Parse "iter_args(%arg = %init, ...)"
2241  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2242  if (hasIterArgs)
2243  if (parser.parseAssignmentList(blockArgs, initArgs))
2244  return failure();
2245 
2246  blockArgs.append(coords);
2247 
2248  SmallVector<Type> iterSpaceTps;
2249  // parse ": sparse_tensor.iter_space -> ret"
2250  if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
2251  return failure();
2252  if (iterSpaceTps.size() != spaces.size())
2253  return parser.emitError(parser.getNameLoc(),
2254  "mismatch in number of iteration space operands "
2255  "and iteration space types");
2256 
2257  for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
2258  IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
2259  if (!spaceTp)
2260  return parser.emitError(parser.getNameLoc(),
2261  "expected sparse_tensor.iter_space type for "
2262  "iteration space operands");
2263  it.type = spaceTp.getIteratorType();
2264  }
2265 
2266  if (hasIterArgs)
2267  if (parser.parseArrowTypeList(state.types))
2268  return failure();
2269 
2270  // Resolves input operands.
2271  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2272  state.operands))
2273  return failure();
2274 
2275  if (hasIterArgs) {
2276  // Strip off the trailing args that are used for coordinates.
2277  MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
2278  if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2279  return parser.emitError(
2280  parser.getNameLoc(),
2281  "mismatch in number of iteration arguments and return values");
2282  }
2283 
2284  for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2285  it.type = tp;
2286  if (parser.resolveOperand(init, tp, state.operands))
2287  return failure();
2288  }
2289  }
2290  return success();
2291 }
2292 
2293 static ParseResult
2294 parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
2295  SmallVectorImpl<Value> &spacesVals,
2296  SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
2297 
2298  // Parse "(%spaces, ...)"
2299  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
2300  if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren))
2301  return failure();
2302 
2304  if (failed(parseUsedCoordList(parser, state, coords)))
2305  return failure();
2306  size_t numCrds = coords.size();
2307 
2308  // Parse "iter_args(%arg = %init, ...)"
2310  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2311  if (hasIterArgs)
2312  if (parser.parseAssignmentList(blockArgs, initArgs))
2313  return failure();
2314  blockArgs.append(coords);
2315 
2316  SmallVector<Type> iterSpaceTps;
2317  // parse ": (sparse_tensor.iter_space, ...) -> ret"
2318  if (parser.parseColon() || parser.parseLParen() ||
2319  parser.parseTypeList(iterSpaceTps) || parser.parseRParen())
2320  return failure();
2321 
2322  if (iterSpaceTps.size() != spaces.size())
2323  return parser.emitError(parser.getNameLoc(),
2324  "mismatch in number of iteration space operands "
2325  "and iteration space types");
2326 
2327  if (hasIterArgs)
2328  if (parser.parseArrowTypeList(state.types))
2329  return failure();
2330 
2331  // Resolves input sparse iteration spaces.
2332  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2333  spacesVals))
2334  return failure();
2335  state.operands.append(spacesVals);
2336 
2337  if (hasIterArgs) {
2338  // Strip off the trailing args that are used for coordinates.
2339  MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
2340  if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2341  return parser.emitError(
2342  parser.getNameLoc(),
2343  "mismatch in number of iteration arguments and return values");
2344  }
2345 
2346  for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2347  it.type = tp;
2348  if (parser.resolveOperand(init, tp, state.operands))
2349  return failure();
2350  }
2351  }
2352  return success();
2353 }
2354 
2355 LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2356  MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
2357  DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
2358  SmallVectorImpl<mlir::Type> &ret) {
2359 
2360  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2361  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
2362  ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
2363  adaptor.getHiLvl()));
2364  return success();
2365 }
2366 
2367 LogicalResult ExtractIterSpaceOp::verify() {
2368  if (getLoLvl() >= getHiLvl())
2369  return emitOpError("expected smaller level low than level high");
2370 
2371  TypedValue<IteratorType> pIter = getParentIter();
2372  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2373  return emitOpError(
2374  "parent iterator should be specified iff level lower bound equals 0");
2375  }
2376 
2377  if (pIter) {
2378  IterSpaceType spaceTp = getExtractedSpace().getType();
2379  if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2380  return emitOpError(
2381  "mismatch in parent iterator encoding and iteration space encoding.");
2382 
2383  if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2384  return emitOpError("parent iterator should be used to extract an "
2385  "iteration space from a consecutive level.");
2386  }
2387 
2388  return success();
2389 }
2390 
2391 LogicalResult ExtractValOp::verify() {
2392  auto stt = getSparseTensorType(getTensor());
2393  auto itTp = getIterator().getType();
2394 
2395  if (stt.getEncoding() != itTp.getEncoding())
2396  return emitOpError("mismatch in tensor encoding and iterator encoding.");
2397 
2398  if (stt.getLvlRank() != itTp.getHiLvl())
2399  return emitOpError("must use last-level iterator to extract values. ");
2400 
2401  return success();
2402 }
2403 
2404 struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
2405  using OpRewritePattern::OpRewritePattern;
2406 
2407  LogicalResult matchAndRewrite(IterateOp iterateOp,
2408  PatternRewriter &rewriter) const override {
2409  I64BitSet newUsedLvls(0);
2410  llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
2411  for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
2412  if (auto crd = iterateOp.getLvlCrd(i)) {
2413  if (crd->getUsers().empty())
2414  toRemove.set(crd->getArgNumber());
2415  else
2416  newUsedLvls.set(i);
2417  }
2418  }
2419 
2420  // All coordinates are used.
2421  if (toRemove.none())
2422  return failure();
2423 
2424  rewriter.startOpModification(iterateOp);
2425  iterateOp.setCrdUsedLvls(newUsedLvls);
2426  iterateOp.getBody()->eraseArguments(toRemove);
2427  rewriter.finalizeOpModification(iterateOp);
2428  return success();
2429  }
2430 };
2431 
2432 void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
2433  mlir::MLIRContext *context) {
2434  results.add<RemoveUnusedLvlCrds>(context);
2435 }
2436 
2437 void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2438  Value iterSpace, ValueRange initArgs) {
2439  unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
2440  // All ones.
2441  I64BitSet set((1 << rank) - 1);
2442  return build(builder, odsState, iterSpace, initArgs, set);
2443 }
2444 
2445 void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2446  Value iterSpace, ValueRange initArgs,
2447  I64BitSet crdUsedLvls) {
2448  OpBuilder::InsertionGuard guard(builder);
2449 
2450  odsState.addOperands(iterSpace);
2451  odsState.addOperands(initArgs);
2452  odsState.getOrAddProperties<Properties>().crdUsedLvls =
2453  builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
2454  Region *bodyRegion = odsState.addRegion();
2455  odsState.addTypes(initArgs.getTypes());
2456  Block *bodyBlock = builder.createBlock(bodyRegion);
2457 
2458  // Starts with a list of user-provided loop arguments.
2459  for (Value v : initArgs)
2460  bodyBlock->addArgument(v.getType(), v.getLoc());
2461 
2462  // Followed by a list of used coordinates.
2463  for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
2464  bodyBlock->addArgument(builder.getIndexType(), odsState.location);
2465 
2466  // Ends with the sparse iterator.
2467  bodyBlock->addArgument(
2468  llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
2469  odsState.location);
2470 }
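// Example (illustrative): building an IterateOp over a rank-2 iteration space
// with one init argument and all coordinates used yields a body block with
// arguments (%arg, %crd0, %crd1, %it): the init args first, then the used
// coordinates, then the sparse iterator itself.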
2471 
2472 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
2473  OpAsmParser::Argument iterator;
2474  OpAsmParser::UnresolvedOperand iterSpace;
2475 
2476  SmallVector<OpAsmParser::Argument> iters, iterArgs;
2477  if (parseSparseIterateLoop(parser, result, iters, iterArgs))
2478  return failure();
2479  if (iters.size() != 1)
2480  return parser.emitError(parser.getNameLoc(),
2481  "expected only one iterator/iteration space");
2482 
2483  iterArgs.append(iters);
2484  Region *body = result.addRegion();
2485  if (parser.parseRegion(*body, iterArgs))
2486  return failure();
2487 
2488  IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2489 
2490  // Parse the optional attribute list.
2491  if (parser.parseOptionalAttrDict(result.attributes))
2492  return failure();
2493 
2494  return success();
2495 }
2496 
2497 /// Prints the initialization list in the form of
2498 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
2499 /// where 'inner' values are assumed to be region arguments and 'outer' values
2500 /// are regular SSA values.
2501 static void printInitializationList(OpAsmPrinter &p,
2502  Block::BlockArgListType blocksArgs,
2503  ValueRange initializers,
2504  StringRef prefix = "") {
2505  assert(blocksArgs.size() == initializers.size() &&
2506  "expected same length of arguments and initializers");
2507  if (initializers.empty())
2508  return;
2509 
2510  p << prefix << '(';
2511  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
2512  p << std::get<0>(it) << " = " << std::get<1>(it);
2513  });
2514  p << ")";
2515 }
2516 
2517 template <typename SparseLoopOp>
2518 static LogicalResult verifySparseLoopOp(SparseLoopOp op) {
2519  if (op.getInitArgs().size() != op.getNumResults()) {
2520  return op.emitOpError(
2521  "mismatch in number of loop-carried values and defined values");
2522  }
2523  if (op.getCrdUsedLvls().max() > op.getSpaceDim())
2524  return op.emitOpError("required out-of-bound coordinates");
2525 
2526  return success();
2527 }
2528 
2529 LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); }
2530 LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(*this); }
2531 
2532 void IterateOp::print(OpAsmPrinter &p) {
2533  p << " " << getIterator() << " in " << getIterSpace();
2534  if (!getCrdUsedLvls().empty()) {
2535  p << " at(";
2536  printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
2537  p << ")";
2538  }
2539  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
2540 
2541  p << " : " << getIterSpace().getType() << " ";
2542  if (!getInitArgs().empty())
2543  p.printArrowTypeList(getInitArgs().getTypes());
2544 
2545  p << " ";
2546  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2547  /*printBlockTerminators=*/!getInitArgs().empty());
2548 }
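// Example (illustrative, syntax abridged): the printer above emits roughly
//   sparse_tensor.iterate %it in %space at(%crd) iter_args(%sum = %init)
//       : !sparse_tensor.iter_space<...> -> f64 { ... }
// omitting "at(...)" when no coordinates are used and the arrow type list when
// there are no init arguments.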
2549 
2550 LogicalResult IterateOp::verifyRegions() {
2551  if (getIterator().getType() != getIterSpace().getType().getIteratorType())
2552  return emitOpError("mismatch in iterator and iteration space type");
2553  if (getNumRegionIterArgs() != getNumResults())
2554  return emitOpError(
2555  "mismatch in number of basic block args and defined values");
2556 
2557  auto initArgs = getInitArgs();
2558  auto iterArgs = getRegionIterArgs();
2559  auto yieldVals = getYieldedValues();
2560  auto opResults = getResults();
2561  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2562  opResults.size()})) {
2563  return emitOpError() << "number mismatch between iter args and results.";
2564  }
2565 
2566  for (auto [i, init, iter, yield, ret] :
2567  llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2568  if (init.getType() != ret.getType())
2569  return emitOpError() << "types mismatch between " << i
2570  << "th iter operand and defined value";
2571  if (iter.getType() != ret.getType())
2572  return emitOpError() << "types mismatch between " << i
2573  << "th iter region arg and defined value";
2574  if (yield.getType() != ret.getType())
2575  return emitOpError() << "types mismatch between " << i
2576  << "th yield value and defined value";
2577  }
2578 
2579  return success();
2580 }
2581 
2582 /// OpInterfaces' methods implemented by IterateOp.
2583 SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
2584 
2585 MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
2586  return getInitArgsMutable();
2587 }
2588 
2589 Block::BlockArgListType IterateOp::getRegionIterArgs() {
2590  return getRegion().getArguments().take_front(getNumRegionIterArgs());
2591 }
2592 
2593 std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
2594  return cast<sparse_tensor::YieldOp>(
2595  getRegion().getBlocks().front().getTerminator())
2596  .getResultsMutable();
2597 }
2598 
2599 std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
2600 
2601 OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2602  return getInitArgs();
2603 }
2604 
2605 void IterateOp::getSuccessorRegions(RegionBranchPoint point,
2606  SmallVectorImpl<RegionSuccessor> &regions) {
2607  // Both the operation itself and the region may be branching into the body
2608  // or back into the operation itself.
2609  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2610  // It is possible for the loop not to enter the body.
2611  regions.push_back(RegionSuccessor(getResults()));
2612 }
2613 
2614 void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
2615  ValueRange iterSpaces, ValueRange initArgs,
2616  unsigned numCases) {
2617  unsigned rank =
2618  cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim();
2619  // All ones.
2620  I64BitSet set((1 << rank) - 1);
2621  // Generates all-zero case bits (they only serve as placeholders), which are
2622  // supposed to be overridden later. We need to preallocate all the regions
2623  // because an mlir::Region cannot be added dynamically after the operation
2624  // has been created.
2625  SmallVector<int64_t> caseBits(numCases, 0);
2626  ArrayAttr cases = builder.getI64ArrayAttr(caseBits);
2627  return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
2628  initArgs, set, cases,
2629  /*caseRegionsCount=*/numCases);
2630 }
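// Example (illustrative): building a CoIterateOp with numCases = 3 initially
// attaches cases = [0, 0, 0] and three empty case regions; the zero bitsets
// are placeholders that callers are expected to overwrite with the actual
// defined-iterator sets once the case bodies are populated.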
2631 
2632 ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {
2633 
2634  SmallVector<Value> spaces;
2635  // The block argument list of each region is arranged in the order of
2636  // ([used coordinate list], [loop iteration args], [sparse iterator list]).
2637  SmallVector<OpAsmParser::Argument> blockArgs;
2638  if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs))
2639  return failure();
2640 
2641  result.addAttribute("operandSegmentSizes",
2642  parser.getBuilder().getDenseI32ArrayAttr(
2643  {static_cast<int32_t>(spaces.size()),
2644  static_cast<int32_t>(result.types.size())}));
2645 
2646  SmallVector<Attribute> cases;
2647  while (succeeded(parser.parseOptionalKeyword("case"))) {
2648  // Parse one region per case.
2649  I64BitSet definedItSet;
2650  SmallVector<OpAsmParser::Argument> definedIts;
2651  if (parseOptionalDefinedList(parser, result, definedItSet, definedIts,
2652  spaces.size(), OpAsmParser::Delimiter::None))
2653  return failure();
2654 
2655  cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet));
2656 
2657  for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) {
2658  // Resolve the iterator type based on the iteration space type.
2659  auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType());
2660  definedIts[i].type = spaceTp.getIteratorType();
2661  }
2662  definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
2663  Region *body = result.addRegion();
2664  if (parser.parseRegion(*body, definedIts))
2665  return failure();
2666 
2667  CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2668  }
2669 
2670  result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases));
2671 
2672  // Parse the optional attribute list.
2673  if (parser.parseOptionalAttrDict(result.attributes))
2674  return failure();
2675 
2676  return success();
2677 }
2678 
2679 void CoIterateOp::print(OpAsmPrinter &p) {
2680  p << " (";
2681  llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; });
2682  p << ")";
2683 
2684  if (!getCrdUsedLvls().empty()) {
2685  p << " at(";
2686  printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls());
2687  p << ")";
2688  }
2689 
2690  printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args");
2691 
2692  p << " : (" << getIterSpaces().getTypes() << ")";
2693  if (!getInitArgs().empty())
2694  p.printArrowTypeList(getInitArgs().getTypes());
2695 
2696  for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
2697  p.printNewline();
2698  p << "case ";
2699  printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx),
2700  getRegionDefinedSpace(idx));
2701  p << " ";
2702  p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false,
2703  /*printBlockTerminators=*/!getInitArgs().empty());
2704  }
2705 }
2706 
2707 ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
2708  return cast<sparse_tensor::YieldOp>(
2709  getRegion(regionIdx).getBlocks().front().getTerminator())
2710  .getResults();
2711 }
2712 
2713 LogicalResult CoIterateOp::verifyRegions() {
2714  for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
2715  if (getNumRegionIterArgs() != getNumResults())
2716  return emitOpError(
2717  "mismatch in number of basic block args and defined values");
2718 
2719  auto initArgs = getInitArgs();
2720  auto iterArgs = getRegionIterArgs(r);
2721  auto yieldVals = getYieldedValues(r);
2722  auto opResults = getResults();
2723  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2724  opResults.size()})) {
2725  return emitOpError()
2726  << "number mismatch between iter args and results on " << r
2727  << "th region";
2728  }
2729 
2730  for (auto [i, init, iter, yield, ret] :
2731  llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2732  if (init.getType() != ret.getType())
2733  return emitOpError()
2734  << "types mismatch between " << i
2735  << "th iter operand and defined value on " << r << "th region";
2736  if (iter.getType() != ret.getType())
2737  return emitOpError() << "types mismatch between " << i
2738  << "th iter region arg and defined value on " << r
2739  << "th region";
2740  if (yield.getType() != ret.getType())
2741  return emitOpError()
2742  << "types mismatch between " << i
2743  << "th yield value and defined value on " << r << "th region";
2744  }
2745  }
2746 
2747  auto cases = getRegionDefinedSpaces();
2748  llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
2749  if (set.size() != getNumRegions())
2750  return emitOpError("contains duplicated cases.");
2751 
2752  return success();
2753 }
2754 
2755 SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
2756  SmallVector<Region *> ret;
2757  I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
2758  for (Region &r : getCaseRegions())
2759  if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
2760  ret.push_back(&r);
2761 
2762  return ret;
2763 }
2764 
2765 //===----------------------------------------------------------------------===//
2766 // Sparse Tensor Dialect Setups.
2767 //===----------------------------------------------------------------------===//
2768 
2769 /// Materialize a single constant operation from a given attribute value with
2770 /// the desired resultant type.
2771 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2772  Attribute value, Type type,
2773  Location loc) {
2774  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2775  return op;
2776  return nullptr;
2777 }
2778 
2779 void SparseTensorDialect::initialize() {
2780  addAttributes<
2781 #define GET_ATTRDEF_LIST
2782 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
2783  >();
2784  addTypes<
2785 #define GET_TYPEDEF_LIST
2786 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
2787  >();
2788  addOperations<
2789 #define GET_OP_LIST
2790 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2791  >();
2792  declarePromisedInterfaces<
2793  bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
2794  NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
2795  ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
2796 }
2797 
2798 #define GET_OP_CLASSES
2799 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2800 
2801 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
Definition: AMXDialect.cpp:85
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static bool isPermutation(std::vector< PermutationTy > permutation)
Definition: IRAffine.cpp:67
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1203::ArityGroupAndKind::Kind kind
bool isUnique(It begin, It end)
Definition: MeshOps.cpp:140
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
static LogicalResult verifyNumBlockArgs(T *op, Region &region, const char *regionName, TypeRange inputTypes, Type outputType)
static ParseResult parseOptionalStaticSlice(int64_t &result, AsmParser &parser)
static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc)
We normalized sparse tensor encoding attribute by always using ordered/unique LT such that "compresse...
static ParseResult parseUsedCoordList(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &coords)
static LogicalResult isMatchingWidth(Value mem, unsigned width)
static constexpr bool acceptBitWidth(unsigned bitWidth)
static mlir::ParseResult parseLevelRange(mlir::AsmParser &, mlir::sparse_tensor::Level &, mlir::sparse_tensor::Level &)
Parses a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static LogicalResult lvlIsInBounds(Level lvl, Value tensor)
static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size, Block::BlockArgListType blocksArgs, I64BitSet definedSet)
static constexpr FieldIndex kDataFieldStartingIdx
static constexpr Level kInvalidLevel
static LogicalResult verifySparseLoopOp(SparseLoopOp op)
static constexpr Level kInvalidFieldIndex
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, mlir::sparse_tensor::Level)
Prints a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind)
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op)
static SmallVector< Size > getSparseFieldShape(const SparseTensorEncodingAttr enc, std::optional< ArrayRef< int64_t >> dimShape)
static ParseResult parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &iterators, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static ParseResult parseOptionalDefinedList(OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, SmallVectorImpl< OpAsmParser::Argument > &definedArgs, unsigned maxCnt=std::numeric_limits< unsigned >::max(), OpAsmParser::Delimiter delimiter=OpAsmParser::Delimiter::Paren)
Parses a list of optional defined list in the form of "(%val0, _, %val1, ...)", where _ is used to an...
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, SparseTensorType stt, RankedTensorType valTp, TypeRange lvlTps)
static ParseResult parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< Value > &spacesVals, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static LogicalResult verifySparsifierGetterSetter(StorageSpecifierKind mdKind, std::optional< Level > lvl, TypedValue< StorageSpecifierType > md, Operation *op)
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, OpaqueProperties prop, RegionRange region, SmallVectorImpl< mlir::Type > &ret)
static bool isAllDense(uint64_t lvlRank, const LevelType *lvlTypes)
Definition: Storage.cpp:20
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
Base type for affine expression.
Definition: AffineExpr.h:68
void print(raw_ostream &os) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:415
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
Definition: AffineMap.cpp:369
unsigned getNumSymbols() const
Definition: AffineMap.cpp:398
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:645
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseLBrace()=0
Parse a { token.
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseQuestion()=0
Parse a '?' token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
Definition: AsmPrinter.cpp:78
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printArrowTypeList(TypeRange &&types)
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
BlockArgListType getArguments()
Definition: Block.h:87
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:161
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:226
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:110
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:69
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:279
IndexType getIndexType()
Definition: Builders.cpp:53
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:428
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:811
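A sketch of registering a pattern with RewritePatternSet::add; MyIteratePattern is a hypothetical OpRewritePattern subclass and `ctx` an MLIRContext*:
RewritePatternSet patterns(ctx);
// Forwards (ctx, /*benefit=*/2) to the pattern's constructor.
patterns.add<MyIteratePattern>(ctx, /*benefit=*/2);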
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:578
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
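A sketch of the Value accessors above, assuming a Value `val` in scope:
Type ty = val.getType();
Location loc = val.getLoc();
if (Operation *def = val.getDefiningOp()) {
  // `val` is an op result; otherwise it is a block argument.
  (void)def;
}
(void)ty; (void)loc;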
A simple wrapper to encode a bitset of (at most 64) levels, currently used by sparse_tensor....
Definition: SparseTensor.h:64
iterator_range< const_set_bits_iterator > bits() const
Definition: SparseTensor.h:75
I64BitSet & set(unsigned i)
Definition: SparseTensor.h:83
A wrapper around RankedTensorType, which has three goals:
unsigned getCrdWidth() const
Returns the coordinate-overhead bitwidth, defaulting to zero.
SmallVector< Size > getBatchLvlShape() const
Returns the batched level-shape.
ArrayRef< LevelType > getLvlTypes() const
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
ArrayRef< Size > getDimShape() const
Returns the dimension-shape.
bool isAllOrdered() const
Returns true for tensors where every level is ordered.
SmallVector< Size > getLvlShape() const
Returns the level-shape.
bool isCOOType(Level startLvl=0, bool isUnique=true) const
Returns true iff this sparse tensor type has a trailing COO region starting at the given level.
Dimension getDimRank() const
Returns the dimension-rank.
bool isAllDense() const
Returns true for tensors where every level is dense.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
bool hasSameDimToLvl(const SparseTensorType &other) const
Returns true iff the two types have the same mapping.
bool hasStaticDimShape() const
Returns true if no dimension has dynamic size.
Level getLvlRank() const
Returns the level-rank.
unsigned getPosWidth() const
Returns the position-overhead bitwidth, defaulting to zero.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
Provides methods to access fields of a sparse tensor with the given encoding.
unsigned getNumDataFields() const
Gets the total number of data fields (coordinate arrays, position arrays, and a value array) for the ...
unsigned getNumFields() const
Gets the total number of fields for the given sparse tensor encoding.
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with t...
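A sketch of iterating the storage fields of an encoding with foreachField; it assumes a valid SparseTensorEncodingAttr `enc` and that StorageLayout is constructible from an encoding:
StorageLayout layout(enc);
layout.foreachField([](FieldIndex fidx, SparseTensorFieldKind kind,
                       Level lvl, LevelType lt) {
  // Inspect one field per call; return true to keep iterating.
  return true;
});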
std::pair< FieldIndex, unsigned > getFieldIndexAndStride(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Parses the Sparse Tensor Encoding Attribute (STEA).
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition: Barvinok.cpp:63
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
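A sketch of constantIndex, assuming an OpBuilder `builder` and Location `loc`:
// Materialize the index constants 0 and 1, e.g. for loop bounds.
Value zero = constantIndex(builder, loc, 0);
Value one = constantIndex(builder, loc, 1);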
bool isWithCrdLT(LevelType lt)
Definition: Enums.h:431
bool isWithPosLT(LevelType lt)
Definition: Enums.h:432
bool isOrderedLT(LevelType lt)
Definition: Enums.h:425
std::string toMLIRString(LevelType lt)
Definition: Enums.h:447
Dimension toDim(SparseTensorEncodingAttr enc, Level l)
Convenience method to translate the given level to the corresponding dimension.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
bool isSingletonLT(LevelType lt)
Definition: Enums.h:421
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
Definition: SparseTensor.h:39
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:42
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
uint64_t getN(LevelType lt)
Definition: Enums.h:442
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
Definition: SparseTensor.h:46
llvm::hash_code hash_value(LevelType lt)
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context)
Given the dimToLvl map, infers the lvlToDim map, or returns empty Affine map when inference fails.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
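A combined sketch of getSparseTensorEncoding and inferLvlToDim above, assuming a tensor Type `type`:
if (SparseTensorEncodingAttr enc = getSparseTensorEncoding(type)) {
  // Re-derive the lvlToDim map from the encoding's dimToLvl map; a null
  // AffineMap is returned when inference fails.
  AffineMap lvlToDim = inferLvlToDim(enc.getDimToLvl(), type.getContext());
  (void)lvlToDim;
}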
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Definition: SparseTensor.h:168
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
bool isBlockSparsity(AffineMap dimToLvl)
Given the dimToLvl map, returns whether it expresses block sparsity.
bool isDenseLT(LevelType lt)
Definition: Enums.h:413
uint64_t getM(LevelType lt)
Definition: Enums.h:443
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
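A sketch combining getSparseTensorType with the SparseTensorType queries listed earlier, assuming a Value `val` of ranked tensor type:
SparseTensorType stt = getSparseTensorType(val);
if (stt.hasEncoding()) {
  Level lvlRank = stt.getLvlRank();   // number of storage levels
  bool ordered = stt.isAllOrdered();  // every level ordered?
  Type posTp = stt.getPosType();      // position-overhead type
  (void)lvlRank; (void)ordered; (void)posTp;
}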
SparseTensorFieldKind
The sparse tensor storage...
bool isBatchLT(LevelType lt)
Definition: Enums.h:414
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context)
Returns the lvlToDim map for the given dimToLvl map specific to the block sparse cases.
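A sketch of the block-sparsity helpers above; it assumes `dimToLvl` is an AffineMap such as the 2x2 BSR map (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2), and `ctx` an MLIRContext*:
if (isBlockSparsity(dimToLvl)) {
  SmallVector<unsigned> blockSizes = getBlockSize(dimToLvl);  // e.g. {2, 2}
  AffineMap lvlToDim = inverseBlockSparsity(dimToLvl, ctx);
  (void)blockSizes; (void)lvlToDim;
}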
std::optional< LevelType > buildLevelType(LevelFormat lf, const std::vector< LevelPropNonDefault > &properties, uint64_t n=0, uint64_t m=0)
Definition: Enums.h:402
bool isNOutOfMLT(LevelType lt)
Definition: Enums.h:424
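A sketch tying together buildLevelType, isNOutOfMLT, getN, and getM; LevelFormat::NOutOfM (the 2:4-style structured format) is an assumed enumerator name from Enums.h:
std::optional<LevelType> lt =
    buildLevelType(LevelFormat::NOutOfM, /*properties=*/{}, /*n=*/2, /*m=*/4);
if (lt && isNOutOfMLT(*lt)) {
  uint64_t n = getN(*lt);  // 2
  uint64_t m = getM(*lt);  // 4
  (void)n; (void)m;
}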
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
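A sketch of getConstantIntValue, assuming an OpFoldResult `ofr`:
// Prefer a static size when the OpFoldResult folds to a constant.
int64_t size = ShapedType::kDynamic;
if (std::optional<int64_t> cst = getConstantIntValue(ofr))
  size = *cst;
(void)size;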
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:488
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:788
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:70
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:645
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using the classes in the detail namespace.
Definition: AffineExpr.cpp:621
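A sketch of the affine-expression helpers above, building and simplifying d0 floordiv 2; it assumes an MLIRContext* `ctx`:
AffineExpr d0 = getAffineDimExpr(0, ctx);
AffineExpr two = getAffineConstantExpr(2, ctx);
AffineExpr expr = getAffineBinaryOpExpr(AffineExprKind::FloorDiv, d0, two);
expr = simplifyAffineExpr(expr, /*numDims=*/1, /*numSymbols=*/0);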
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
LogicalResult matchAndRewrite(IterateOp iterateOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.
A simple structure that encodes a range of levels in the sparse tensors that forms a COO segment.
Definition: SparseTensor.h:50
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.
Definition: Enums.h:326
LevelType stripStorageIrrelevantProperties() const
Definition: Enums.h:299