//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "Detail/DimLvlMapParser.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"

// Forward declarations; the following custom print/parsing methods are
// referenced by the generated code for SparseTensorTypes.td.
static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
                                         mlir::sparse_tensor::Level &,
                                         mlir::sparse_tensor::Level &);
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
                            mlir::sparse_tensor::Level);

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
// well.
namespace mlir::sparse_tensor {
llvm::hash_code hash_value(LevelType lt) {
  return llvm::hash_value(static_cast<uint64_t>(lt));
}
} // namespace mlir::sparse_tensor

//===----------------------------------------------------------------------===//
// Local Convenience Methods.
//===----------------------------------------------------------------------===//

static constexpr bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}

static SmallVector<Size>
getSparseFieldShape(const SparseTensorEncodingAttr enc,
                    std::optional<ArrayRef<int64_t>> dimShape) {
  assert(enc);
  // With only the encoding, we cannot determine the static shape of the
  // leading batch levels; we therefore return a dynamically shaped memref.
  SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
  if (dimShape.has_value()) {
    // If the actual tensor shape is provided, we can refine the leading
    // batch dimensions.
    SmallVector<int64_t> lvlShape =
        enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
    memrefShape.assign(lvlShape.begin(),
                       lvlShape.begin() + enc.getBatchLvlRank());
  }
  // One more dynamic dimension to store the sparse level.
  memrefShape.push_back(ShapedType::kDynamic);
  return memrefShape;
}
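
// For example (derived from the code above, purely illustrative): with no
// batch levels this helper returns the rank-1 shape {kDynamic}, i.e. a
// memref<?xT> field; with two batch levels and a known 8x8x? dimension shape,
// the leading entries are refined to {8, 8, kDynamic}.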

//===----------------------------------------------------------------------===//
// SparseTensorDialect StorageLayout.
//===----------------------------------------------------------------------===//

static constexpr Level kInvalidLevel = -1u;
static constexpr Level kInvalidFieldIndex = -1u;
static constexpr FieldIndex kDataFieldStartingIdx = 0;

void StorageLayout::foreachField(
    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) const {
  const auto lvlTypes = enc.getLvlTypes();
  const Level lvlRank = enc.getLvlRank();
  SmallVector<COOSegment> cooSegs = enc.getCOOSegments();
  FieldIndex fieldIdx = kDataFieldStartingIdx;

  ArrayRef cooSegsRef = cooSegs;
  // Per-level storage.
  for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
    const auto lt = lvlTypes[l];
    if (isWithPosLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
        return;
    }
    if (isWithCrdLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
        return;
    }
    if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
      if (!cooSegsRef.front().isSoA) {
        // AoS COO: all singletons are fused into one memref. Skip the entire
        // COO segment.
        l = cooSegsRef.front().lvlRange.second;
      } else {
        // SoA COO: each singleton level has its own memref.
        l++;
      }
      // Expire handled COO segment.
      cooSegsRef = cooSegsRef.drop_front();
    } else {
      // Non-COO levels.
      l++;
    }
  }
  // The values array.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
  // Put metadata at the end.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
}
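
// Illustrative walk-through (assuming a plain CSR encoding, i.e. lvlTypes =
// [dense, compressed]): the dense level contributes no fields, so the callback
// above fires with, in order, (0, PosMemRef, lvl 1), (1, CrdMemRef, lvl 1),
// (2, ValMemRef), and finally (3, StorageSpec).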

void sparse_tensor::foreachFieldAndTypeInSparseTensor(
    SparseTensorType stt,
    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) {
  assert(stt.hasEncoding());

  SmallVector<int64_t> memrefShape =
      getSparseFieldShape(stt.getEncoding(), stt.getDimShape());

  const Type specType = StorageSpecifierType::get(stt.getEncoding());
  // memref<[batch] x ? x pos>  positions
  const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
  // memref<[batch] x ? x crd>  coordinates
  const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
  // memref<[batch] x ? x eltType>  values
  const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());

  StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
                                   callback](FieldIndex fieldIdx,
                                             SparseTensorFieldKind fieldKind,
                                             Level lvl, LevelType lt) -> bool {
    switch (fieldKind) {
    case SparseTensorFieldKind::StorageSpec:
      return callback(specType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::PosMemRef:
      return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::CrdMemRef:
      return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::ValMemRef:
      return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
    };
    llvm_unreachable("unrecognized field kind");
  });
}

unsigned StorageLayout::getNumFields() const {
  unsigned numFields = 0;
  foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
    numFields++;
    return true;
  });
  return numFields;
}

unsigned StorageLayout::getNumDataFields() const {
  unsigned numFields = 0; // one value memref
  foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
    if (fidx >= kDataFieldStartingIdx)
      numFields++;
    return true;
  });
  numFields -= 1; // the last field is the StorageSpecifier
  assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
  return numFields;
}

std::pair<FieldIndex, unsigned>
StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
                                      std::optional<Level> lvl) const {
  FieldIndex fieldIdx = kInvalidFieldIndex;
  unsigned stride = 1;
  if (kind == SparseTensorFieldKind::CrdMemRef) {
    assert(lvl.has_value());
    const Level cooStart = enc.getAoSCOOStart();
    const Level lvlRank = enc.getLvlRank();
    if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
      lvl = cooStart;
      stride = lvlRank - cooStart;
    }
  }
  foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
                                      SparseTensorFieldKind fKind, Level fLvl,
                                      LevelType lt) -> bool {
    if ((lvl && fLvl == lvl.value() && kind == fKind) ||
        (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
      fieldIdx = fIdx;
      // Returns false to break the iteration.
      return false;
    }
    return true;
  });
  assert(fieldIdx != kInvalidFieldIndex);
  return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//

std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
  return isDynamic(v) ? std::nullopt
                      : std::make_optional(static_cast<uint64_t>(v));
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
  return getStatic(getOffset());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
  return getStatic(getStride());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
  return getStatic(getSize());
}

bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
  return isDynamic(getOffset()) && isDynamic(getStride()) &&
         isDynamic(getSize());
}

std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
  return isDynamic(v) ? "?" : std::to_string(v);
}

void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
  assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
  os << '(';
  os << getStaticString(getOffset());
  os << ", ";
  os << getStaticString(getSize());
  os << ", ";
  os << getStaticString(getStride());
  os << ')';
}
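
// For example, a slice with offset 1, size 4, and stride 2 prints as
// "(1, 4, 2)"; a completely dynamic slice prints as "(?, ?, ?)".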

void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
  print(printer.getStream());
}

static ParseResult parseOptionalStaticSlice(int64_t &result,
                                            AsmParser &parser) {
  auto parseResult = parser.parseOptionalInteger(result);
  if (parseResult.has_value()) {
    if (parseResult.value().succeeded() && result < 0) {
      parser.emitError(
          parser.getCurrentLocation(),
          "expect positive value or ? for slice offset/size/stride");
      return failure();
    }
    return parseResult.value();
  }

  // Otherwise, expect a '?' representing a dynamic slice value.
  result = SparseTensorDimSliceAttr::kDynamic;
  return parser.parseQuestion();
}

Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
  int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;

  if (failed(parser.parseLParen()) ||
      failed(parseOptionalStaticSlice(offset, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(size, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(stride, parser)) ||
      failed(parser.parseRParen()))
    return {};

  return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
                                                     offset, size, stride);
}

LogicalResult
SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 int64_t offset, int64_t size, int64_t stride) {
  if (!isDynamic(offset) && offset < 0)
    return emitError() << "expect non-negative value or ? for slice offset";
  if (!isDynamic(size) && size <= 0)
    return emitError() << "expect positive value or ? for slice size";
  if (!isDynamic(stride) && stride <= 0)
    return emitError() << "expect positive value or ? for slice stride";
  return success();
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
  return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
  return withDimToLvl(AffineMap());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                        unsigned crdWidth) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
      crdWidth, getExplicitVal(), getImplicitVal());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
  return withBitWidths(0, 0);
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), explicitVal, getImplicitVal());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
  return withExplicitVal(Attribute());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), implicitVal);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
  return withImplicitVal(Attribute());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
  return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
}

uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
  ArrayRef<LevelType> lvlTypes = getLvlTypes();
  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  return std::distance(lastBatch, lvlTypes.rend());
}
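
// For example, lvlTypes [batch, batch, compressed] has a batch level rank of
// 2; an encoding without leading batch levels has a batch level rank of 0.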

bool SparseTensorEncodingAttr::isAllDense() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
}

bool SparseTensorEncodingAttr::isAllOrdered() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
}

Type SparseTensorEncodingAttr::getCrdElemType() const {
  if (!getImpl())
    return nullptr;
  if (getCrdWidth())
    return IntegerType::get(getContext(), getCrdWidth());
  return IndexType::get(getContext());
}

Type SparseTensorEncodingAttr::getPosElemType() const {
  if (!getImpl())
    return nullptr;
  if (getPosWidth())
    return IntegerType::get(getContext(), getPosWidth());
  return IndexType::get(getContext());
}

MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
  return MemRefType::get(shape, getCrdElemType());
}

MemRefType SparseTensorEncodingAttr::getPosMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
  return MemRefType::get(shape, getPosElemType());
}

bool SparseTensorEncodingAttr::isIdentity() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
}

bool SparseTensorEncodingAttr::isPermutation() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
}

Dimension SparseTensorEncodingAttr::getDimRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  const auto dimToLvl = getDimToLvl();
  return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
}

Level SparseTensorEncodingAttr::getLvlRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return getLvlTypes().size();
}

LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
  if (!getImpl())
    return LevelFormat::Batch;
  assert(l < getLvlRank() && "Level is out of bounds");
  return getLvlTypes()[l];
}

bool SparseTensorEncodingAttr::isSlice() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return !getDimSlices().empty();
}

SparseTensorDimSliceAttr
SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
  assert(isSlice() && "Is not a slice");
  const auto dimSlices = getDimSlices();
  assert(dim < dimSlices.size() && "Dimension is out of bounds");
  return dimSlices[dim];
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
  return getDimSlice(dim).getStaticOffset();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
  return getDimSlice(dim).getStaticStride();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
  return getStaticDimSliceOffset(toDim(*this, lvl));
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
  return getStaticDimSliceStride(toDim(*this, lvl));
}

SmallVector<int64_t>
SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
                                         CrdTransDirectionKind dir) const {
  if (isIdentity())
    return SmallVector<int64_t>(srcShape);

  SmallVector<int64_t> ret;
  unsigned rank =
      dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
  ret.reserve(rank);

  if (isPermutation()) {
    for (unsigned r = 0; r < rank; r++) {
      unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
                                                             : toLvl(*this, r);
      ret.push_back(srcShape[trans]);
    }
    return ret;
  }

  // Handle non-permutation maps.
  AffineMap transMap =
      dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();

  SmallVector<AffineExpr> dimRep;
  dimRep.reserve(srcShape.size());
  for (int64_t sz : srcShape) {
    if (ShapedType::isStatic(sz)) {
      // Push back the max coordinate for the given dimension/level size.
      dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
    } else {
      // A dynamic size; use an AffineDimExpr to symbolize the value.
      dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
    }
  };

  for (AffineExpr exp : transMap.getResults()) {
    // Do constant propagation on the affine map.
    AffineExpr evalExp =
        simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
    // Use the llvm namespace here to avoid ambiguity.
    if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
      ret.push_back(c.getValue() + 1);
    } else {
      if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
          mod && mod.getKind() == AffineExprKind::Mod) {
        // We can still infer a static bound for expressions of the form
        // "d % constant", since d % constant is in [0, constant).
        if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
          ret.push_back(bound.getValue());
          continue;
        }
      }
      ret.push_back(ShapedType::kDynamic);
    }
  }
  assert(ret.size() == rank);
  return ret;
}
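
// For example, under the 2x2 block map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
// the dimension shape [6, 8] translates (dim2lvl) to the level shape
// [3, 4, 2, 2]; the "mod" results keep the static bound 2 even when the
// corresponding dimension size is dynamic.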

ValueRange
SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
                                        ValueRange crds,
                                        CrdTransDirectionKind dir) const {
  if (!getImpl())
    return crds;

  SmallVector<Type> retType(
      dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
      builder.getIndexType());
  auto transOp =
      CrdTranslateOp::create(builder, loc, retType, crds, dir, *this);
  return transOp.getOutCrds();
}

Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  // Open "<{" part.
  if (failed(parser.parseLess()))
    return {};
  if (failed(parser.parseLBrace()))
    return {};

  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<LevelType> lvlTypes;
  SmallVector<SparseTensorDimSliceAttr> dimSlices;
  AffineMap dimToLvl = {};
  AffineMap lvlToDim = {};
  unsigned posWidth = 0;
  unsigned crdWidth = 0;
  Attribute explicitVal;
  Attribute implicitVal;
  StringRef attrName;
  SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
                                    "explicitVal", "implicitVal"};
  while (succeeded(parser.parseOptionalKeyword(&attrName))) {
    // Detect admissible keyword.
    auto *it = find(keys, attrName);
    if (it == keys.end()) {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
      return {};
    }
    unsigned keyWordIndex = it - keys.begin();
    // Consume the `=` after the key.
    if (failed(parser.parseEqual()))
      return {};
    // Dispatch on the keyword.
    switch (keyWordIndex) {
    case 0: { // map
      ir_detail::DimLvlMapParser cParser(parser);
      auto res = cParser.parseDimLvlMap();
      if (failed(res))
        return {};
      const auto &dlm = *res;

      const Level lvlRank = dlm.getLvlRank();
      for (Level lvl = 0; lvl < lvlRank; lvl++)
        lvlTypes.push_back(dlm.getLvlType(lvl));

      const Dimension dimRank = dlm.getDimRank();
      for (Dimension dim = 0; dim < dimRank; dim++)
        dimSlices.push_back(dlm.getDimSlice(dim));
      // NOTE: the old syntax requires an all-or-nothing approach to
      // `dimSlices`; therefore, if any slice actually exists then we need
      // to convert null-DSA into default/nop DSA.
      const auto isDefined = [](SparseTensorDimSliceAttr slice) {
        return static_cast<bool>(slice.getImpl());
      };
      if (llvm::any_of(dimSlices, isDefined)) {
        const auto defaultSlice =
            SparseTensorDimSliceAttr::get(parser.getContext());
        for (Dimension dim = 0; dim < dimRank; dim++)
          if (!isDefined(dimSlices[dim]))
            dimSlices[dim] = defaultSlice;
      } else {
        dimSlices.clear();
      }

      dimToLvl = dlm.getDimToLvlMap(parser.getContext());
      lvlToDim = dlm.getLvlToDimMap(parser.getContext());
      break;
    }
    case 1: { // posWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral position bitwidth");
        return {};
      }
      posWidth = intAttr.getInt();
      break;
    }
    case 2: { // crdWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      crdWidth = intAttr.getInt();
      break;
    }
    case 3: { // explicitVal
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
        explicitVal = result;
      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
        explicitVal = result;
      } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
        explicitVal = result;
      } else {
        parser.emitError(parser.getNameLoc(),
                         "expected a numeric value for explicitVal");
        return {};
      }
      break;
    }
    case 4: { // implicitVal
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
        implicitVal = result;
      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
        implicitVal = result;
      } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
        implicitVal = result;
      } else {
        parser.emitError(parser.getNameLoc(),
                         "expected a numeric value for implicitVal");
        return {};
      }
      break;
    }
    } // switch
    // Only the last item can omit the comma.
    if (parser.parseOptionalComma().failed())
      break;
  }

  // Close "}>" part.
  if (failed(parser.parseRBrace()))
    return {};
  if (failed(parser.parseGreater()))
    return {};

  // Construct struct-like storage for attribute.
  if (!lvlToDim || lvlToDim.isEmpty()) {
    lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
  }
  return parser.getChecked<SparseTensorEncodingAttr>(
      parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
      explicitVal, implicitVal, dimSlices);
}
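
// For example, the parser above accepts a CSR encoding written as
//
//   #CSR = #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed),
//     posWidth = 32,
//     crdWidth = 32
//   }>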

void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
  auto map = static_cast<AffineMap>(getDimToLvl());
  // Empty affine map indicates identity map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
  printer << "<{ map = ";
  printSymbols(map, printer);
  printer << '(';
  printDimensions(map, printer, getDimSlices());
  printer << ") -> (";
  printLevels(map, printer, getLvlTypes());
  printer << ')';
  // Print remaining members only for non-default values.
  if (getPosWidth())
    printer << ", posWidth = " << getPosWidth();
  if (getCrdWidth())
    printer << ", crdWidth = " << getCrdWidth();
  if (getExplicitVal()) {
    printer << ", explicitVal = " << getExplicitVal();
  }
  if (getImplicitVal())
    printer << ", implicitVal = " << getImplicitVal();
  printer << " }>";
}

void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
                                            AsmPrinter &printer) const {
  if (map.getNumSymbols() == 0)
    return;
  printer << '[';
  for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
    printer << 's' << i << ", ";
  if (map.getNumSymbols() >= 1)
    printer << 's' << map.getNumSymbols() - 1;
  printer << ']';
}

void SparseTensorEncodingAttr::printDimensions(
    AffineMap &map, AsmPrinter &printer,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  if (!dimSlices.empty()) {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << " : " << dimSlices[i] << ", ";
    if (map.getNumDims() >= 1) {
      printer << 'd' << map.getNumDims() - 1 << " : "
              << dimSlices[map.getNumDims() - 1];
    }
  } else {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << ", ";
    if (map.getNumDims() >= 1)
      printer << 'd' << map.getNumDims() - 1;
  }
}

void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
                                           ArrayRef<LevelType> lvlTypes) const {
  for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
    map.getResult(i).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
  }
  if (map.getNumResults() >= 1) {
    auto lastIndex = map.getNumResults() - 1;
    map.getResult(lastIndex).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[lastIndex]);
  }
}

LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
    AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
    unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
  if (!acceptBitWidth(posWidth))
    return emitError() << "unexpected position bitwidth: " << posWidth;
  if (!acceptBitWidth(crdWidth))
    return emitError() << "unexpected coordinate bitwidth: " << crdWidth;

  // Verify every COO segment.
  auto *it = llvm::find_if(lvlTypes, isSingletonLT);
  while (it != lvlTypes.end()) {
    if (it == lvlTypes.begin() ||
        !(it - 1)->isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>())
      return emitError() << "expected compressed or loose_compressed level "
                            "before singleton level";

    auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT);
    if (!std::all_of(it, curCOOEnd, isSingletonLT))
      return emitError() << "expected all singleton lvlTypes "
                            "following a singleton level";
    // We can potentially support mixed SoA/AoS singleton levels.
    if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
          return it->isa<LevelPropNonDefault::SoA>() ==
                 i.isa<LevelPropNonDefault::SoA>();
        })) {
      return emitError() << "expected all singleton lvlTypes stored in the "
                            "same memory layout (SoA vs AoS).";
    }
    it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
  }

  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
    return emitError() << "Batch lvlType can only be leading levels.";

  // The SoA property can only be applied to singleton levels.
  auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
    return lt.isa<LevelPropNonDefault::SoA>();
  });
  if (llvm::any_of(soaLvls, [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      })) {
    return emitError() << "SoA is only applicable to singleton lvlTypes.";
  }

  // TODO: audit formats that actually are supported by backend.
  if (auto it = llvm::find_if(lvlTypes, isNOutOfMLT);
      it != std::end(lvlTypes)) {
    if (it != lvlTypes.end() - 1)
      return emitError() << "expected n_out_of_m to be the last level type";
    if (!std::all_of(lvlTypes.begin(), it, isDenseLT))
      return emitError() << "expected all dense lvlTypes "
                            "before a n_out_of_m level";
    if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
      if (!isBlockSparsity(dimToLvl)) {
        return emitError()
               << "expected 1xm block structure for n_out_of_m level";
      }
      auto sizes = getBlockSize(dimToLvl);
      unsigned coefficient = 0;
      for (const auto &elem : sizes) {
        if (elem != 0) {
          if (elem != coefficient && coefficient != 0) {
            return emitError() << "expected only one blocked level "
                                  "with the same coefficients";
          }
          coefficient = elem;
        }
      }
      if (coefficient != getM(*it)) {
        return emitError() << "expected coefficients of affine expressions "
                              "to be equal to m of n_out_of_m level";
      }
    }
  }
  // Before we can check that the level-rank is consistent/coherent
  // across all fields, we need to define it. The source-of-truth for
  // the `getLvlRank` method is the length of the level-types array,
  // since it must always be provided and have full rank; therefore we
  // use that same source-of-truth here.
  const Level lvlRank = lvlTypes.size();
  if (lvlRank == 0)
    return emitError() << "expected a non-empty array for lvlTypes";
  // We save `dimRank` here because we'll also need it to verify `dimSlices`.
  const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
  if (dimToLvl) {
    if (dimToLvl.getNumResults() != lvlRank)
      return emitError()
             << "level-rank mismatch between dimToLvl and lvlTypes: "
             << dimToLvl.getNumResults() << " != " << lvlRank;
    auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
    // Symbols can't be inferred but are acceptable.
    if (!inferRes && dimToLvl.getNumSymbols() == 0)
      return emitError() << "failed to infer lvlToDim from dimToLvl";
    if (lvlToDim && (inferRes != lvlToDim))
      return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
    if (dimRank > lvlRank)
      return emitError() << "unexpected dimToLvl mapping from " << dimRank
                         << " to " << lvlRank;
  }
  if (!dimSlices.empty()) {
    if (dimSlices.size() != dimRank)
      return emitError()
             << "dimension-rank mismatch between dimSlices and dimToLvl: "
             << dimSlices.size() << " != " << dimRank;
    // Compiler support for `dimSlices` currently requires that the two
    // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
    if (dimRank != lvlRank)
      return emitError()
             << "dimSlices expected dimension-rank to match level-rank: "
             << dimRank << " != " << lvlRank;
  }
  return success();
}

LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<Size> dimShape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity. In particular, this ensures that the
  // level-rank is coherent across all the fields.
  if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
                    getPosWidth(), getCrdWidth(), getExplicitVal(),
                    getImplicitVal(), getDimSlices())))
    return failure();
  // Check integrity with tensor type specifics. In particular, we
  // need only check that the dimension-rank of the tensor agrees with
  // the dimension-rank of the encoding.
  const Dimension dimRank = dimShape.size();
  if (dimRank == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimRank() != dimRank)
    return emitError()
           << "dimension-rank mismatch between encoding and tensor shape: "
           << getDimRank() << " != " << dimRank;
  if (auto expVal = getExplicitVal()) {
    Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
    if (attrType != elementType) {
      return emitError() << "explicit value type mismatch between encoding and "
                         << "tensor element type: " << attrType
                         << " != " << elementType;
    }
  }
  if (auto impVal = getImplicitVal()) {
    Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
    if (attrType != elementType) {
      return emitError() << "implicit value type mismatch between encoding and "
                         << "tensor element type: " << attrType
                         << " != " << elementType;
    }
    // Currently, we only support zero as the implicit value.
    auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
    auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
    auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
    if ((impFVal && impFVal.getValue().isNonZero()) ||
        (impIntVal && !impIntVal.getValue().isZero()) ||
        (impComplexVal && (impComplexVal.getImag().isNonZero() ||
                           impComplexVal.getReal().isNonZero()))) {
      return emitError() << "implicit value must be zero";
    }
  }
  return success();
}

Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
  SmallVector<COOSegment> coo = getCOOSegments();
  assert(coo.size() == 1 || coo.empty());
  if (!coo.empty() && coo.front().isAoS()) {
    return coo.front().lvlRange.first;
  }
  return getLvlRank();
}

SmallVector<COOSegment>
mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
  SmallVector<COOSegment> ret;
  if (getLvlRank() <= 1)
    return ret;

  ArrayRef<LevelType> lts = getLvlTypes();
  Level l = 0;
  while (l < getLvlRank()) {
    auto lt = lts[l];
    if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
      auto cur = lts.begin() + l;
      auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      });
      unsigned cooLen = std::distance(cur, end);
      if (cooLen > 1) {
        // To support mixed SoA/AoS COO, we should break the segment when the
        // storage scheme changes; for now we faithfully assume that all
        // consecutive singleton levels share the same storage format, as
        // verified for the STEA.
        ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
                                 lts[l + 1].isa<LevelPropNonDefault::SoA>()});
      }
      l += cooLen;
    } else {
      l++;
    }
  }
  return ret;
}
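
// For example, the classic 2-D COO level sequence
// [compressed(nonunique), singleton] yields one segment with lvlRange [0, 2);
// had the singleton level carried the SoA property, the segment would be
// marked isSoA instead.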

//===----------------------------------------------------------------------===//
// SparseTensorType Methods.
//===----------------------------------------------------------------------===//

bool SparseTensorType::isCOOType(Level startLvl, bool isUnique) const {
  if (!hasEncoding())
    return false;
  if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
    return false;
  for (Level l = startLvl + 1; l < lvlRank; ++l)
    if (!isSingletonLvl(l))
      return false;
  // If isUnique is true, then make sure that the last level is unique,
  // that is, when lvlRank == 1, the only compressed level is unique,
  // and when lvlRank > 1, the last singleton is unique.
  return !isUnique || isUniqueLvl(lvlRank - 1);
}

RankedTensorType
SparseTensorType::getCOOType(bool ordered) const {
  SmallVector<LevelType> lvlTypes;
  lvlTypes.reserve(lvlRank);
  // A non-unique compressed level at the beginning (unless this is
  // also the last level, in which case it is unique).
  lvlTypes.push_back(
      *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
  if (lvlRank > 1) {
    // Followed by n-2 non-unique singleton levels.
    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
                *buildLevelType(LevelFormat::Singleton, ordered, false));
    // Ends with a unique singleton level.
    lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
  }
  auto enc = SparseTensorEncodingAttr::get(
      getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal());
  return RankedTensorType::get(getDimShape(), getElementType(), enc);
}

//===----------------------------------------------------------------------===//
// Convenience Methods.
//===----------------------------------------------------------------------===//

SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
  if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
    return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
  if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
    return mdtp.getEncoding();
  return nullptr;
}

AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
                                             MLIRContext *context) {
  auto map = static_cast<AffineMap>(dimToLvl);
  AffineMap lvlToDim;
  // Return an empty lvlToDim when inference is not successful.
  if (!map || map.getNumSymbols() != 0) {
    lvlToDim = AffineMap();
  } else if (map.isPermutation()) {
    lvlToDim = inversePermutation(map);
  } else if (isBlockSparsity(map)) {
    lvlToDim = inverseBlockSparsity(map, context);
  }
  return lvlToDim;
}

AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
                                                    MLIRContext *context) {
  SmallVector<AffineExpr> lvlExprs;
  auto numLvls = dimToLvl.getNumResults();
  lvlExprs.reserve(numLvls);
  // lvlExprComponents stores information of the floordiv and mod operations
  // applied to the same dimension, so as to build the lvlToDim map.
  std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
  for (unsigned i = 0, n = numLvls; i < n; i++) {
    auto result = dimToLvl.getResult(i);
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::FloorDiv) {
        // Position of the dimension in dimToLvl.
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
               "expected only one floordiv for each dimension");
        SmallVector<AffineExpr, 3> components;
        // Level variable for floordiv.
        components.push_back(getAffineDimExpr(i, context));
        // Multiplier.
        components.push_back(binOp.getRHS());
        // Map key is the position of the dimension.
        lvlExprComponents[pos] = components;
      } else if (result.getKind() == AffineExprKind::Mod) {
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
               "expected floordiv before mod");
        // Add the level variable for mod to the same vector
        // as the corresponding floordiv.
        lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
      } else {
        assert(false && "expected floordiv or mod");
      }
    } else {
      lvlExprs.push_back(getAffineDimExpr(i, context));
    }
  }
  // Build lvlExprs from lvlExprComponents.
  // For example, for il = i floordiv 2 and ii = i mod 2, the components
  // would be [il, 2, ii]. It could be used to build the AffineExpr
  // i = il * 2 + ii in lvlToDim.
  for (auto &components : lvlExprComponents) {
    assert(components.second.size() == 3 &&
           "expected 3 components to build lvlExprs");
    auto mulOp = getAffineBinaryOpExpr(
        AffineExprKind::Mul, components.second[0], components.second[1]);
    auto addOp =
        getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
    lvlExprs.push_back(addOp);
  }
  return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
}
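
// For example, for the 2x2 BSR map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
// this routine computes the inverse
//   (d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 2 + d3).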

SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
  assert(isBlockSparsity(dimToLvl) &&
         "expected dimToLvl to be block sparsity for calling getBlockSize");
  SmallVector<unsigned> blockSize;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::Mod) {
        blockSize.push_back(
            dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
      }
    } else {
      blockSize.push_back(0);
    }
  }
  return blockSize;
}
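
// For example, the 2x3 block map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// returns {2, 3}, while an unblocked dimension (a plain AffineDimExpr result)
// contributes a 0 entry.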

bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
  if (!dimToLvl)
    return false;
  std::map<unsigned, int64_t> coefficientMap;
  bool hasBlock = false;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      // Check for "dim op const".
      auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
      auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
      if (!dimOp || !conOp || conOp.getValue() <= 0)
        return false;
      // Inspect "dim / const" or "dim % const".
      auto pos = dimOp.getPosition();
      if (binOp.getKind() == AffineExprKind::FloorDiv) {
        // Expect only one floordiv for each dimension.
        auto [it, inserted] = coefficientMap.try_emplace(pos);
        if (!inserted)
          return false;
        // Record the coefficient of the floordiv.
        it->second = conOp.getValue();
      } else if (binOp.getKind() == AffineExprKind::Mod) {
        // Expect a floordiv before the mod.
        auto it = coefficientMap.find(pos);
        if (it == coefficientMap.end())
          return false;
        // Expect the mod to have the same coefficient as the floordiv.
        if (conOp.getValue() != it->second)
          return false;
        hasBlock = true;
      } else {
        return false;
      }
    } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
      auto pos = dimOp.getPosition();
      // Expect the dimension to be unset.
      if (!coefficientMap.try_emplace(pos, 0).second)
        return false;
    } else {
      return false;
    }
  }
  return hasBlock;
}

bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
  auto hasNonIdentityMap = [](Value v) {
    auto stt = tryGetSparseTensorType(v);
    return stt && !stt->isIdentity();
  };

  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
         llvm::any_of(op->getResults(), hasNonIdentityMap);
}

Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
  if (enc) {
    assert(enc.isPermutation() && "Non-permutation map not supported");
    if (const auto dimToLvl = enc.getDimToLvl())
      return dimToLvl.getDimPosition(l);
  }
  return l;
}

Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
  if (enc) {
    assert(enc.isPermutation() && "Non-permutation map not supported");
    if (const auto lvlToDim = enc.getLvlToDim())
      return lvlToDim.getDimPosition(d);
  }
  return d;
}

/// We normalize the sparse tensor encoding attribute by always using
/// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
/// as other variants) lead to the same storage specifier type, and by stripping
/// irrelevant fields that do not alter the sparse tensor memory layout.
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
  SmallVector<LevelType> lts;
  for (auto lt : enc.getLvlTypes())
    lts.push_back(lt.stripStorageIrrelevantProperties());

  return SparseTensorEncodingAttr::get(
      enc.getContext(), lts,
      AffineMap(), // dimToLvl (irrelevant to storage specifier)
      AffineMap(), // lvlToDim (irrelevant to storage specifier)
      // Always use `index` for memSize and lvlSize instead of reusing
      // `getPosWidth` and `getCrdWidth`. This allows us to reuse the same SSA
      // value for different bitwidths, and it also avoids casting between
      // index and integer (returned by DimOp).
      0, 0,
      Attribute(), // explicitVal (irrelevant to storage specifier)
      Attribute(), // implicitVal (irrelevant to storage specifier)
      enc.getDimSlices());
}
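
// For example, a CSR encoding carrying posWidth = 64, crdWidth = 32, and a
// dimToLvl permutation normalizes to the same specifier encoding as its
// default-width, identity-mapped counterpart, so both share a single
// StorageSpecifierType.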

StorageSpecifierType
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
  return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
}

StorageSpecifierType
StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                                 MLIRContext *ctx,
                                 SparseTensorEncodingAttr encoding) {
  return Base::getChecked(emitError, ctx,
                          getNormalizedEncodingForSpecifier(encoding));
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
//===----------------------------------------------------------------------===//

static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
  return success(lvl < getSparseTensorType(tensor).getLvlRank());
}

static LogicalResult isMatchingWidth(Value mem, unsigned width) {
  const Type etp = getMemRefType(mem).getElementType();
  return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
}

static LogicalResult verifySparsifierGetterSetter(
    StorageSpecifierKind mdKind, std::optional<Level> lvl,
    TypedValue<StorageSpecifierType> md, Operation *op) {
  if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
    return op->emitError(
        "redundant level argument for querying value memory size");
  }

  const auto enc = md.getType().getEncoding();
  const Level lvlRank = enc.getLvlRank();

  if (mdKind == StorageSpecifierKind::DimOffset ||
      mdKind == StorageSpecifierKind::DimStride)
    if (!enc.isSlice())
      return op->emitError("requested slice data on non-slice tensor");

  if (mdKind != StorageSpecifierKind::ValMemSize) {
    if (!lvl)
      return op->emitError("missing level argument");

    const Level l = lvl.value();
    if (l >= lvlRank)
      return op->emitError("requested level is out of bounds");

    if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
      return op->emitError(
          "requested position memory size on a singleton level");
  }
  return success();
}

static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
  switch (kind) {
  case SparseTensorFieldKind::CrdMemRef:
    return stt.getCrdType();
  case SparseTensorFieldKind::PosMemRef:
    return stt.getPosType();
  case SparseTensorFieldKind::ValMemRef:
    return stt.getElementType();
  case SparseTensorFieldKind::StorageSpec:
    return nullptr;
  }
  llvm_unreachable("Unrecognizable FieldKind");
}

static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                      SparseTensorType stt,
                                      RankedTensorType valTp,
                                      TypeRange lvlTps) {
  if (requiresStaticShape && !stt.hasStaticDimShape())
    return op->emitError("the sparse-tensor must have static shape");
  if (!stt.hasEncoding())
    return op->emitError("the sparse-tensor must have an encoding attribute");

  // Verifies the trailing COO.
  Level cooStartLvl = stt.getAoSCOOStart();
  if (cooStartLvl < stt.getLvlRank()) {
    // We only support trailing COO for now; it must be the last input.
    auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
    // The coordinates should have the shape <? x rank>.
    unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
    if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
      return op->emitError("input/output trailing COO level-ranks don't match");
    }
  }

  // Verifies that all types match.
  StorageLayout layout(stt.getEncoding());
  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
    return op->emitError("inconsistent number of fields between input/output");

  unsigned idx = 0;
  bool misMatch = false;
  layout.foreachField([&idx, &misMatch, stt, valTp,
                       lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
                               Level lvl, LevelType lt) -> bool {
    if (fKind == SparseTensorFieldKind::StorageSpec)
      return true;

    Type inputTp = nullptr;
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      inputTp = valTp;
    } else {
      assert(fid == idx && stt.getLvlType(lvl) == lt);
      inputTp = lvlTps[idx++];
    }
    // The input element type and expected element type should match.
    Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
    Type expElemTp = getFieldElemType(stt, fKind);
    if (inpElemTp != expElemTp) {
      misMatch = true;
      return false; // to terminate the iteration
    }
    return true;
  });

  if (misMatch)
    return op->emitError("input/output element-types don't match");
  return success();
}

LogicalResult AssembleOp::verify() {
  RankedTensorType valuesTp = getValues().getType();
  const auto lvlsTp = getLevels().getTypes();
  const auto resTp = getSparseTensorType(getResult());
  return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
}
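
// For reference, this verifies assembly of roughly the following form (the
// %pos/%crd/%val names and shapes are illustrative, assuming a 3x4 CSR
// tensor with 6 stored entries):
//
//   %t = sparse_tensor.assemble (%pos, %crd), %val
//      : (tensor<4xindex>, tensor<6xindex>), tensor<6xf64>
//        to tensor<3x4xf64, #CSR>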

LogicalResult DisassembleOp::verify() {
  if (getOutValues().getType() != getRetValues().getType())
    return emitError("output values and return value type mismatch");

  for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
    if (ot.getType() != rt.getType())
      return emitError("output levels and return levels type mismatch");

  RankedTensorType valuesTp = getRetValues().getType();
  const auto lvlsTp = getRetLevels().getTypes();
  const auto srcTp = getSparseTensorType(getTensor());
  return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
}

LogicalResult ConvertOp::verify() {
  RankedTensorType tp1 = getSource().getType();
  RankedTensorType tp2 = getDest().getType();
  if (tp1.getRank() != tp2.getRank())
    return emitError("unexpected conversion mismatch in rank");
  auto dstEnc =
      llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
  if (dstEnc && dstEnc.isSlice())
    return emitError("cannot convert to a sparse tensor slice");

  auto shape1 = tp1.getShape();
  auto shape2 = tp2.getShape();
  // Accept size matches between the source and the destination type
  // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
  // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
  for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
    if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
      return emitError("unexpected conversion mismatch in dimension ") << d;
  return success();
}

OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
  if (getType() == getSource().getType())
    return getSource();
  return {};
}

bool ConvertOp::needsExtraSort() {
  SparseTensorType srcStt = getSparseTensorType(getSource());
  SparseTensorType dstStt = getSparseTensorType(getDest());

  // We do not need an extra sort when returning unordered sparse tensors or
  // dense tensors, since dense tensors support random access.
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
      srcStt.hasSameDimToLvl(dstStt)) {
    return false;
  }

  // Source and dest tensors are ordered in different ways. We only do direct
  // dense to sparse conversion when the dense input is defined by a sparse
  // constant. Note that we can theoretically always directly convert from
  // dense inputs by rotating dense loops, but it leads to bad cache locality
  // and hurts performance.
  if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
    if (isa<SparseElementsAttr>(constOp.getValue()))
      return false;

  return true;
}

LogicalResult CrdTranslateOp::verify() {
  uint64_t inRank = getEncoder().getLvlRank();
  uint64_t outRank = getEncoder().getDimRank();

  if (getDirection() == CrdTransDirectionKind::dim2lvl)
    std::swap(inRank, outRank);

  if (inRank != getInCrds().size() || outRank != getOutCrds().size())
    return emitError("Coordinate rank mismatch with encoding");

  return success();
}

LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                   SmallVectorImpl<OpFoldResult> &results) {
  if (getEncoder().isIdentity()) {
    results.assign(getInCrds().begin(), getInCrds().end());
    return success();
  }
  if (getEncoder().isPermutation()) {
    AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
                         ? getEncoder().getDimToLvl()
                         : getEncoder().getLvlToDim();
    for (AffineExpr exp : perm.getResults())
      results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
    return success();
  }

  // Fuse dim2lvl/lvl2dim pairs.
  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
  bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
                   return v.getDefiningOp() == def;
                 });
  if (!sameDef)
    return failure();

  bool oppositeDir = def.getDirection() != getDirection();
  bool sameOracle =
      def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
  bool sameCount = def.getNumResults() == getInCrds().size();
  if (!oppositeDir || !sameOracle || !sameCount)
    return failure();

  // The definition produces the coordinates in the same order as the input
  // coordinates.
  bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
                                [](auto valuePair) {
                                  auto [lhs, rhs] = valuePair;
                                  return lhs == rhs;
                                });

  if (!sameOrder)
    return failure();
  // l1 = dim2lvl (lvl2dim l0)
  // ==> l0
  results.append(def.getInCrds().begin(), def.getInCrds().end());
  return success();
}

void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
                  int64_t index) {
  Value val = arith::ConstantIndexOp::create(builder, state.location, index);
  return build(builder, state, source, val);
}

LogicalResult LvlOp::verify() {
  if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
    auto stt = getSparseTensorType(getSource());
    if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
      return emitError(
          "Level index exceeds the rank of the input sparse tensor");
  }
  return success();
}

std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability LvlOp::getSpeculatability() {
  auto constantIndex = getConstantLvlIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  assert(constantIndex <
         cast<RankedTensorType>(getSource().getType()).getRank());
  return Speculation::Speculatable;
}

OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
  auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!lvlIndex)
    return {};

  Level lvl = lvlIndex.getAPSInt().getZExtValue();
  auto stt = getSparseTensorType(getSource());
  if (lvl >= stt.getLvlRank()) {
    // Follows the same convention used by the tensor.dim operation. Out of
    // bound indices produce undefined behavior but are still valid IR. Don't
    // choke on them.
    return {};
  }

  // Helper lambda to build an IndexAttr.
  auto getIndexAttr = [this](int64_t lvlSz) {
    return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
  };

  SmallVector<Size> lvlShape = stt.getLvlShape();
  if (ShapedType::isStatic(lvlShape[lvl]))
    return getIndexAttr(lvlShape[lvl]);

  return {};
}
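
// For example, for a tensor of type tensor<3x4xf64, #CSR> (identity dimToLvl),
// a sparse_tensor.lvl query with constant index 1 folds to the index constant
// 4, while any query on a dynamically sized level does not fold.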
1538 
1539 void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1540  SparseTensorEncodingAttr dstEnc, Value source) {
1541  auto srcStt = getSparseTensorType(source);
1542  SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
1543  SmallVector<int64_t> dstDimShape =
1544  dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1545  auto dstTp =
1546  RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
1547  return build(odsBuilder, odsState, dstTp, source);
1548 }
1549 
1550 LogicalResult ReinterpretMapOp::verify() {
1551  auto srcStt = getSparseTensorType(getSource());
1552  auto dstStt = getSparseTensorType(getDest());
1553  ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
1554  ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();
1555 
1556  if (srcLvlTps.size() != dstLvlTps.size())
1557  return emitError("Level rank mismatch between source/dest tensors");
1558 
1559  for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1560  if (srcLvlTp != dstLvlTp)
1561  return emitError("Level type mismatch between source/dest tensors");
1562 
1563  if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
1564  srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
1565  return emitError("Crd/Pos width mismatch between source/dest tensors");
1566  }
1567 
1568  if (srcStt.getElementType() != dstStt.getElementType())
1569  return emitError("Element type mismatch between source/dest tensors");
1570 
1571  SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
1572  SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
1573  for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1574  if (srcLvlSz != dstLvlSz) {
1575  // Should we allow one side to be dynamic size, e.g., <?x?> should be
1576  // compatible to <3x4>? For now, we require all the level sizes to be
1577  // *exactly* matched for simplicity.
1578  return emitError("Level size mismatch between source/dest tensors");
1579  }
1580  }
1581 
1582  return success();
1583 }
1584 
1585 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1586  if (getSource().getType() == getDest().getType())
1587  return getSource();
1588 
1589  if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1590  // A -> B, B -> A ==> A
1591  if (def.getSource().getType() == getDest().getType())
1592  return def.getSource();
1593  }
1594  return {};
1595 }
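
// Example (sketch; #A and #B are placeholder encodings that share level
// types, sizes, and overhead widths, as the verifier requires): because of
// the second fold case above, a round trip cancels out,
//
//   %1 = sparse_tensor.reinterpret_map %0 : tensor<2x3xf64, #A> to tensor<2x3xf64, #B>
//   %2 = sparse_tensor.reinterpret_map %1 : tensor<2x3xf64, #B> to tensor<2x3xf64, #A>
//
// so %2 folds directly to %0.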
1596 
1597 template <typename ToBufferOp>
1598 static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
1599  OpaqueProperties prop,
1600  RegionRange region,
1601  SmallVectorImpl<mlir::Type> &ret) {
1602  typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1603  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
1604  Type elemTp = nullptr;
1605  bool withStride = false;
1606  if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1607  elemTp = stt.getPosType();
1608  } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1609  std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1610  elemTp = stt.getCrdType();
1611  if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1612  withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
1613  } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1614  elemTp = stt.getElementType();
1615  }
1616 
1617  assert(elemTp && "unhandled operation.");
1618  SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
1619  bufShape.push_back(ShapedType::kDynamic);
1620 
1621  auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get(
1622  stt.getContext(), ShapedType::kDynamic,
1623  {ShapedType::kDynamic})
1624  : StridedLayoutAttr();
1625  ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1626  return success();
1627 }
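
// For instance (a sketch, assuming #CSR over f64 with the default index
// overhead types), the inferred buffer types are batch-free 1-d memrefs:
//   ToPositionsOp   (level 1) -> memref<?xindex>
//   ToCoordinatesOp (level 1) -> memref<?xindex>
//   ToValuesOp                -> memref<?xf64>
// Only coordinates at or past an AoS COO start additionally carry the
// dynamic strided layout computed above.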
1628 
1629 LogicalResult ToPositionsOp::verify() {
1630  auto stt = getSparseTensorType(getTensor());
1631  if (failed(lvlIsInBounds(getLevel(), getTensor())))
1632  return emitError("requested level is out of bounds");
1633  if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
1634  return emitError("unexpected type for positions");
1635  return success();
1636 }
1637 
1638 LogicalResult
1639 ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1640  ValueRange ops, DictionaryAttr attr,
1641  OpaqueProperties prop, RegionRange region,
1642  SmallVectorImpl<Type> &ret) {
1643  return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
1644 }
1645 
1646 LogicalResult ToCoordinatesOp::verify() {
1647  auto stt = getSparseTensorType(getTensor());
1648  if (failed(lvlIsInBounds(getLevel(), getTensor())))
1649  return emitError("requested level is out of bounds");
1650  if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
1651  return emitError("unexpected type for coordinates");
1652  return success();
1653 }
1654 
1655 LogicalResult
1656 ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1657  ValueRange ops, DictionaryAttr attr,
1658  OpaqueProperties prop, RegionRange region,
1659  SmallVectorImpl<Type> &ret) {
1660  return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
1661 }
1662 
1663 LogicalResult ToCoordinatesBufferOp::verify() {
1664  auto stt = getSparseTensorType(getTensor());
1665  if (stt.getAoSCOOStart() >= stt.getLvlRank())
1666  return emitError("expected sparse tensor with a COO region");
1667  return success();
1668 }
1669 
1670 LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1671  MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
1672  DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
1673  SmallVectorImpl<Type> &ret) {
1674  return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
1675  ret);
1676 }
1677 
1678 LogicalResult ToValuesOp::verify() {
1679  auto stt = getSparseTensorType(getTensor());
1680  auto mtp = getMemRefType(getResult());
1681  if (stt.getElementType() != mtp.getElementType())
1682  return emitError("unexpected mismatch in element types");
1683  return success();
1684 }
1685 
1686 LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1687  std::optional<Location> loc,
1688  ValueRange ops, DictionaryAttr attr,
1689  OpaqueProperties prop,
1690  RegionRange region,
1691  SmallVectorImpl<Type> &ret) {
1692  return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
1693 }
1694 
1695 LogicalResult ToSliceOffsetOp::verify() {
1696  auto rank = getSlice().getType().getRank();
1697  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1698  return emitError("requested dimension out of bound");
1699  return success();
1700 }
1701 
1702 LogicalResult ToSliceStrideOp::verify() {
1703  auto rank = getSlice().getType().getRank();
1704  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1705  return emitError("requested dimension out of bound");
1706  return success();
1707 }
1708 
1709 LogicalResult GetStorageSpecifierOp::verify() {
1710  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1711  getSpecifier(), getOperation());
1712 }
1713 
1714 template <typename SpecifierOp>
1715 static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
1716  return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1717 }
1718 
1719 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1720  const StorageSpecifierKind kind = getSpecifierKind();
1721  const auto lvl = getLevel();
1722  for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1723  if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1724  return op.getValue();
1725  return {};
1726 }
1727 
1728 LogicalResult SetStorageSpecifierOp::verify() {
1729  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1730  getSpecifier(), getOperation());
1731 }
1732 
1733 template <class T>
1734 static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1735  const char *regionName,
1736  TypeRange inputTypes, Type outputType) {
1737  unsigned numArgs = region.getNumArguments();
1738  unsigned expectedNum = inputTypes.size();
1739  if (numArgs != expectedNum)
1740  return op->emitError() << regionName << " region must have exactly "
1741  << expectedNum << " arguments";
1742 
1743  for (unsigned i = 0; i < numArgs; i++) {
1744  Type typ = region.getArgument(i).getType();
1745  if (typ != inputTypes[i])
1746  return op->emitError() << regionName << " region argument " << (i + 1)
1747  << " type mismatch";
1748  }
1749  Operation *term = region.front().getTerminator();
1750  YieldOp yield = dyn_cast<YieldOp>(term);
1751  if (!yield)
1752  return op->emitError() << regionName
1753  << " region must end with sparse_tensor.yield";
1754  if (!yield.hasSingleResult() ||
1755  yield.getSingleResult().getType() != outputType)
1756  return op->emitError() << regionName << " region yield type mismatch";
1757 
1758  return success();
1759 }
1760 
1761 LogicalResult BinaryOp::verify() {
1762  NamedAttrList attrs = (*this)->getAttrs();
1763  Type leftType = getX().getType();
1764  Type rightType = getY().getType();
1765  Type outputType = getOutput().getType();
1766  Region &overlap = getOverlapRegion();
1767  Region &left = getLeftRegion();
1768  Region &right = getRightRegion();
1769 
1770  // Check correct number of block arguments and return type for each
1771  // non-empty region.
1772  if (!overlap.empty()) {
1773  if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1774  TypeRange{leftType, rightType}, outputType)))
1775  return failure();
1776  }
1777  if (!left.empty()) {
1778  if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1779  outputType)))
1780  return failure();
1781  } else if (getLeftIdentity()) {
1782  if (leftType != outputType)
1783  return emitError("left=identity requires first argument to have the same "
1784  "type as the output");
1785  }
1786  if (!right.empty()) {
1787  if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1788  outputType)))
1789  return failure();
1790  } else if (getRightIdentity()) {
1791  if (rightType != outputType)
1792  return emitError("right=identity requires second argument to have the "
1793  "same type as the output");
1794  }
1795  return success();
1796 }
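
// Example (sketch of a well-formed op under the rules above; operand names
// and the exact attribute spelling are illustrative): an overlap region
// takes both operand types and yields the output type.
//
//   %r = sparse_tensor.binary %a, %b : f64, f64 to f64
//     overlap={
//       ^bb0(%x: f64, %y: f64):
//         %s = arith.addf %x, %y : f64
//         sparse_tensor.yield %s : f64
//     }
//     left=identity
//     right=identity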
1797 
1798 LogicalResult UnaryOp::verify() {
1799  Type inputType = getX().getType();
1800  Type outputType = getOutput().getType();
1801 
1802  // Check correct number of block arguments and return type for each
1803  // non-empty region.
1804  Region &present = getPresentRegion();
1805  if (!present.empty()) {
1806  if (failed(verifyNumBlockArgs(this, present, "present",
1807  TypeRange{inputType}, outputType)))
1808  return failure();
1809  }
1810  Region &absent = getAbsentRegion();
1811  if (!absent.empty()) {
1812  if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
1813  outputType)))
1814  return failure();
1815  // Absent branch can only yield invariant values.
1816  Block *absentBlock = &absent.front();
1817  Block *parent = getOperation()->getBlock();
1818  Value absentVal =
1819  cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
1820  if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
1821  if (arg.getOwner() == parent)
1822  return emitError("absent region cannot yield linalg argument");
1823  } else if (Operation *def = absentVal.getDefiningOp()) {
1824  if (!isa<arith::ConstantOp>(def) &&
1825  (def->getBlock() == absentBlock || def->getBlock() == parent))
1826  return emitError("absent region cannot yield locally computed value");
1827  }
1828  }
1829  return success();
1830 }
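
// Example (sketch): under the checks above, an absent region may only yield
// an invariant value, typically a constant,
//
//   absent={
//     %c0 = arith.constant 0.0 : f64   // ok: constants are invariant
//     sparse_tensor.yield %c0 : f64
//   }
//
// whereas yielding a value computed inside the absent block (e.g., an
// arith.addf) or a block argument of the enclosing op is rejected.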
1831 
1832 bool ConcatenateOp::needsExtraSort() {
1833  SparseTensorType dstStt = getSparseTensorType(*this);
1834  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1835  return false;
1836 
1837  bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1838  return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1839  });
1840  // TODO: When conDim != 0, as long as conDim corresponds to the first level
1841  // in all input/output buffers, and all input/output buffers have the same
1842  // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenating
1843  // CSC matrices along columns).
1844  bool directLowerable =
1845  allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1846  return !directLowerable;
1847 }
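
// Example (sketch): concatenating two CSR matrices along dimension 0 into a
// CSR result is directly lowerable (identity dimToLvl, dimension == 0), so
// no extra sort is needed; concatenating into a CSC result requires the
// temporary COO buffer and a sort, since CSC's dimToLvl is not the identity.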
1848 
1849 LogicalResult ConcatenateOp::verify() {
1850  const auto dstTp = getSparseTensorType(*this);
1851  const Dimension concatDim = getDimension();
1852  const Dimension dimRank = dstTp.getDimRank();
1853 
1854  if (getInputs().size() <= 1)
1855  return emitError("Need at least two tensors to concatenate.");
1856 
1857  if (concatDim >= dimRank)
1858  return emitError(llvm::formatv(
1859  "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1860  concatDim, dimRank));
1861 
1862  for (const auto &it : llvm::enumerate(getInputs())) {
1863  const auto i = it.index();
1864  const auto srcTp = getSparseTensorType(it.value());
1865  if (srcTp.hasDynamicDimShape())
1866  return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1867  const Dimension srcDimRank = srcTp.getDimRank();
1868  if (srcDimRank != dimRank)
1869  return emitError(
1870  llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1871  "from the output tensor (rank={2}).",
1872  i, srcDimRank, dimRank));
1873  }
1874 
1875  for (Dimension d = 0; d < dimRank; d++) {
1876  const Size dstSh = dstTp.getDimShape()[d];
1877  if (d == concatDim) {
1878  if (ShapedType::isStatic(dstSh)) {
1879  // If we reach here, then all inputs have static shapes. So we
1880  // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1881  // to avoid redundant assertions in the loop.
1882  Size sumSz = 0;
1883  for (const auto src : getInputs())
1884  sumSz += getSparseTensorType(src).getDimShape()[d];
1885  // If all dimensions are statically known, the sum of all the input
1886  // dimensions should be equal to the output dimension.
1887  if (sumSz != dstSh)
1888  return emitError(
1889  "The concatenation dimension of the output tensor should be the "
1890  "sum of all the concatenation dimensions of the input tensors.");
1891  }
1892  } else {
1893  Size prev = dstSh;
1894  for (const auto src : getInputs()) {
1895  const auto sh = getSparseTensorType(src).getDimShape()[d];
1896  if (ShapedType::isStatic(prev) && sh != prev)
1897  return emitError("All dimensions (expect for the concatenating one) "
1898  "should be equal.");
1899  prev = sh;
1900  }
1901  }
1902  }
1903 
1904  return success();
1905 }
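
// Example (sketch): concatenating tensor<3x4xf64, #X> and tensor<5x4xf64, #X>
// along dimension 0 must produce tensor<8x4xf64, #X>: the concatenation
// dimension sums (3 + 5 = 8), while every other static dimension (here 4)
// must match exactly across all inputs and the output.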
1906 
1907 void PushBackOp::build(OpBuilder &builder, OperationState &result,
1908  Value curSize, Value inBuffer, Value value) {
1909  build(builder, result, curSize, inBuffer, value, Value());
1910 }
1911 
1912 LogicalResult PushBackOp::verify() {
1913  if (Value n = getN()) {
1914  std::optional<int64_t> nValue = getConstantIntValue(n);
1915  if (nValue && nValue.value() < 1)
1916  return emitOpError("n must be at least 1");
1917  }
1918  return success();
1919 }
1920 
1921 LogicalResult CompressOp::verify() {
1922  const auto stt = getSparseTensorType(getTensor());
1923  if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1924  return emitOpError("incorrect number of coordinates");
1925  return success();
1926 }
1927 
1928 void ForeachOp::build(
1929  OpBuilder &builder, OperationState &result, Value tensor,
1930  ValueRange initArgs, AffineMapAttr order,
1931  function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
1932  bodyBuilder) {
1933  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
1934  // Builds foreach body.
1935  if (!bodyBuilder)
1936  return;
1937  const auto stt = getSparseTensorType(tensor);
1938  const Dimension dimRank = stt.getDimRank();
1939 
1940  // Starts with `dimRank`-many coordinates.
1941  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
1942  // Followed by one value.
1943  blockArgTypes.push_back(stt.getElementType());
1944  // Followed by the reduction variables.
1945  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
1946 
1947  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());
1948 
1949  OpBuilder::InsertionGuard guard(builder);
1950  auto &region = *result.regions.front();
1951  Block *bodyBlock =
1952  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1953  bodyBuilder(builder, result.location,
1954  bodyBlock->getArguments().slice(0, dimRank),
1955  bodyBlock->getArguments()[dimRank],
1956  bodyBlock->getArguments().drop_front(dimRank + 1));
1957 }
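
// The body block built above is thus laid out as
//   (%crd_0, ..., %crd_{dimRank-1}, %value, %reduc_0, ..., %reduc_{n-1}),
// e.g., for a 2-d f64 tensor with a single init argument:
//   ^bb0(%i: index, %j: index, %v: f64, %acc: f64)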
1958 
1959 LogicalResult ForeachOp::verify() {
1960  const auto t = getSparseTensorType(getTensor());
1961  const Dimension dimRank = t.getDimRank();
1962  const auto args = getBody()->getArguments();
1963 
1964  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1965  return emitError("Level traversal order does not match tensor's level rank");
1966 
1967  if (dimRank + 1 + getInitArgs().size() != args.size())
1968  return emitError("Unmatched number of arguments in the block");
1969 
1970  if (getNumResults() != getInitArgs().size())
1971  return emitError("Mismatch in number of init arguments and results");
1972 
1973  if (getResultTypes() != getInitArgs().getTypes())
1974  return emitError("Mismatch in types of init arguments and results");
1975 
1976  // Cannot mark this const, because the getters aren't.
1977  auto yield = cast<YieldOp>(getBody()->getTerminator());
1978  if (yield.getNumOperands() != getNumResults() ||
1979  yield.getOperands().getTypes() != getResultTypes())
1980  return emitError("Mismatch in types of yield values and results");
1981 
1982  const auto iTp = IndexType::get(getContext());
1983  for (Dimension d = 0; d < dimRank; d++)
1984  if (args[d].getType() != iTp)
1985  return emitError(
1986  llvm::formatv("Expecting Index type for argument at index {0}", d));
1987 
1988  const auto elemTp = t.getElementType();
1989  const auto valueTp = args[dimRank].getType();
1990  if (elemTp != valueTp)
1991  return emitError(
1992  llvm::formatv("Unmatched element type between input tensor and "
1993  "block argument, expected:{0}, got: {1}",
1994  elemTp, valueTp));
1995  return success();
1996 }
1997 
1998 OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
1999  if (getSparseTensorEncoding(getInputCoo().getType()) ==
2000  getSparseTensorEncoding(getResultCoo().getType()))
2001  return getInputCoo();
2002 
2003  return {};
2004 }
2005 
2006 LogicalResult ReorderCOOOp::verify() {
2007  SparseTensorType srcStt = getSparseTensorType(getInputCoo());
2008  SparseTensorType dstStt = getSparseTensorType(getResultCoo());
2009 
2010  if (!srcStt.isCOOType() || !dstStt.isCOOType())
2011  return emitError("Expected COO sparse tensors only");
2012 
2013  if (!srcStt.hasSameDimToLvl(dstStt))
2014  return emitError("Unmatched dim2lvl map between input and result COO");
2015 
2016  if (srcStt.getPosType() != dstStt.getPosType() ||
2017  srcStt.getCrdType() != dstStt.getCrdType() ||
2018  srcStt.getElementType() != dstStt.getElementType())
2019  return emitError("Unmatched storage format between input and result COO");
2020 
2021  return success();
2022 }
2023 
2024 LogicalResult ReduceOp::verify() {
2025  Type inputType = getX().getType();
2026  Region &formula = getRegion();
2027  return verifyNumBlockArgs(this, formula, "reduce",
2028  TypeRange{inputType, inputType}, inputType);
2029 }
2030 
2031 LogicalResult SelectOp::verify() {
2032  Builder b(getContext());
2033  Type inputType = getX().getType();
2034  Type boolType = b.getI1Type();
2035  Region &formula = getRegion();
2036  return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
2037  boolType);
2038 }
2039 
2040 LogicalResult SortOp::verify() {
2041  AffineMap xPerm = getPermMap();
2042  uint64_t nx = xPerm.getNumDims();
2043  if (nx < 1)
2044  return emitError(llvm::formatv("Expected rank(perm_map) >= 1, got {0}", nx));
2045 
2046  if (!xPerm.isPermutation())
2047  return emitError(
2048  llvm::formatv("Expected a permutation map, got {0}", xPerm));
2049 
2050  // We can't check the size of the buffers when n or buffer dimensions aren't
2051  // compile-time constants.
2052  std::optional<int64_t> cn = getConstantIntValue(getN());
2053  if (!cn)
2054  return success();
2055 
2056  // Verify dimensions.
2057  const auto checkDim = [&](Value v, Size minSize,
2058  const char *message) -> LogicalResult {
2059  const Size sh = getMemRefType(v).getShape()[0];
2060  if (ShapedType::isStatic(sh) && sh < minSize)
2061  return emitError(
2062  llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
2063  return success();
2064  };
2065  uint64_t n = cn.value();
2066  uint64_t ny = 0;
2067  if (auto nyAttr = getNyAttr())
2068  ny = nyAttr.getInt();
2069  if (failed(checkDim(getXy(), n * (nx + ny),
2070  "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
2071  return failure();
2072  for (Value opnd : getYs())
2073  if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
2074  return failure();
2075 
2076  return success();
2077 }
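
// Concrete instance of the checks above (sketch): with n = 10,
// nx = rank(perm_map) = 2, and ny = 1, the xy buffer must hold at least
// 10 * (2 + 1) = 30 elements and every y buffer at least 10 elements;
// statically smaller buffers are rejected.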
2078 
2079 //===----------------------------------------------------------------------===//
2080 // Sparse Tensor Iteration Operations.
2081 //===----------------------------------------------------------------------===//
2082 
2083 IterSpaceType IteratorType::getIterSpaceType() const {
2084  return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
2085  getHiLvl());
2086 }
2087 
2088 IteratorType IterSpaceType::getIteratorType() const {
2089  return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
2090 }
2091 
2092 /// Parses a level range in the form "$lo `to` $hi"
2093 /// or simply "$lo" if $hi - $lo = 1
2094 static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
2095  Level &lvlHi) {
2096  if (parser.parseInteger(lvlLo))
2097  return failure();
2098 
2099  if (succeeded(parser.parseOptionalKeyword("to"))) {
2100  if (parser.parseInteger(lvlHi))
2101  return failure();
2102  } else {
2103  lvlHi = lvlLo + 1;
2104  }
2105 
2106  if (lvlHi <= lvlLo)
2107  return parser.emitError(parser.getNameLoc(),
2108  "expected level upper bound to be larger than the lower bound");
2109 
2110  return success();
2111 }
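
// Hence "2" parses as the unit range [2, 3), "0 to 3" parses as [0, 3),
// and a malformed range such as "3 to 1" is rejected by the bound check.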
2112 
2113 /// Parses a level range in the form "$lo `to` $hi"
2114 /// or simply "$lo" if $hi - $lo = 1
2115 static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
2116  IntegerAttr &lvlHiAttr) {
2117  Level lvlLo, lvlHi;
2118  if (parseLevelRange(parser, lvlLo, lvlHi))
2119  return failure();
2120 
2121  lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
2122  lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
2123  return success();
2124 }
2125 
2126 /// Prints a level range in the form "$lo `to` $hi"
2127 /// or simply "$lo" if $hi - $lo = 1
2128 static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2129 
2130  if (lo + 1 == hi)
2131  p << lo;
2132  else
2133  p << lo << " to " << hi;
2134 }
2135 
2136 /// Prints a level range in the form "$lo `to` $hi"
2137 /// or simply "$lo" if $hi - $lo = 1
2138 static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
2139  IntegerAttr lvlHi) {
2140  unsigned lo = lvlLo.getValue().getZExtValue();
2141  unsigned hi = lvlHi.getValue().getZExtValue();
2142  printLevelRange(p, lo, hi);
2143 }
2144 
2145 /// Parses a list of optionally defined values in the form of
2146 /// "(%val0, _, %val1, ...)", where `_` is used to annotate that the
2147 /// corresponding value is not defined (e.g., to represent an undefined
2148 /// coordinate in the sparse iteration space).
2149 static ParseResult parseOptionalDefinedList(
2150  OpAsmParser &parser, OperationState &state, I64BitSet &definedSet,
2151  SmallVectorImpl<OpAsmParser::Argument> &definedArgs,
2152  unsigned maxCnt = std::numeric_limits<unsigned>::max(),
2153  OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) {
2154  unsigned cnt = 0;
2155  ParseResult crdList =
2156  parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult {
2157  if (parser.parseOptionalKeyword("_")) {
2158  if (parser.parseArgument(definedArgs.emplace_back()))
2159  return failure();
2160  definedSet.set(cnt);
2161  }
2162  cnt += 1;
2163  return success();
2164  });
2165 
2166  if (cnt > maxCnt)
2167  return parser.emitError(parser.getNameLoc(),
2168  "parsed more values than expected.");
2169 
2170  if (failed(crdList)) {
2171  return parser.emitError(
2172  parser.getNameLoc(),
2173  "expecting SSA value or \"_\" for level coordinates");
2174  }
2175  assert(definedArgs.size() == definedSet.count());
2176  return success();
2177 }
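
// Example (sketch): parsing "(%v0, _, %v1)" pushes the two defined
// arguments %v0 and %v1 onto definedArgs and sets bits 0 and 2 of
// definedSet (i.e., 0b101), leaving bit 1 clear for the undefined slot.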
2178 
2179 static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size,
2180  Block::BlockArgListType blocksArgs,
2181  I64BitSet definedSet) {
2182  if (definedSet.empty())
2183  return;
2184 
2185  for (unsigned i = 0; i < size; i++) {
2186  if (definedSet[i]) {
2187  p << blocksArgs.front();
2188  blocksArgs = blocksArgs.drop_front();
2189  } else {
2190  p << "_";
2191  }
2192  if (i != size - 1)
2193  p << ", ";
2194  }
2195  assert(blocksArgs.empty());
2196 }
2197 
2198 static ParseResult
2199 parseUsedCoordList(OpAsmParser &parser, OperationState &state,
2200  SmallVectorImpl<OpAsmParser::Argument> &coords) {
2201  // Parse "at(%crd0, _, ...)"
2202  I64BitSet crdUsedLvlSet;
2203  if (succeeded(parser.parseOptionalKeyword("at")) &&
2204  failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords)))
2205  return failure();
2206 
2207  // Always use IndexType for the coordinate.
2208  for (auto &coord : coords)
2209  coord.type = parser.getBuilder().getIndexType();
2210 
2211  // Set the CrdUsedLvl bitset.
2212  state.addAttribute("crdUsedLvls",
2213  parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
2214  return success();
2215 }
2216 
2217 static ParseResult
2218 parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
2219  SmallVectorImpl<OpAsmParser::Argument> &iterators,
2220  SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
2221  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
2222  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
2223 
2224  // Parse "%iters, ... in %spaces, ..."
2225  if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
2226  parser.parseOperandList(spaces))
2227  return failure();
2228 
2229  if (iterators.size() != spaces.size())
2230  return parser.emitError(
2231  parser.getNameLoc(),
2232  "mismatch in number of sparse iterators and sparse spaces");
2233 
2234  SmallVector<OpAsmParser::Argument> coords;
2235  if (failed(parseUsedCoordList(parser, state, coords)))
2236  return failure();
2237  size_t numCrds = coords.size();
2238 
2239  // Parse "iter_args(%arg = %init, ...)"
2240  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2241  if (hasIterArgs)
2242  if (parser.parseAssignmentList(blockArgs, initArgs))
2243  return failure();
2244 
2245  blockArgs.append(coords);
2246 
2247  SmallVector<Type> iterSpaceTps;
2248  // parse ": sparse_tensor.iter_space -> ret"
2249  if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
2250  return failure();
2251  if (iterSpaceTps.size() != spaces.size())
2252  return parser.emitError(parser.getNameLoc(),
2253  "mismatch in number of iteration space operands "
2254  "and iteration space types");
2255 
2256  for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
2257  IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
2258  if (!spaceTp)
2259  return parser.emitError(parser.getNameLoc(),
2260  "expected sparse_tensor.iter_space type for "
2261  "iteration space operands");
2262  it.type = spaceTp.getIteratorType();
2263  }
2264 
2265  if (hasIterArgs)
2266  if (parser.parseArrowTypeList(state.types))
2267  return failure();
2268 
2269  // Resolves input operands.
2270  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2271  state.operands))
2272  return failure();
2273 
2274  if (hasIterArgs) {
2275  // Strip off the trailing args that are used for coordinates.
2276  MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
2277  if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2278  return parser.emitError(
2279  parser.getNameLoc(),
2280  "mismatch in number of iteration arguments and return values");
2281  }
2282 
2283  for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2284  it.type = tp;
2285  if (parser.resolveOperand(init, tp, state.operands))
2286  return failure();
2287  }
2288  }
2289  return success();
2290 }
2291 
2292 static ParseResult
2293 parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
2294  SmallVectorImpl<Value> &spacesVals,
2295  SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
2296 
2297  // Parse "(%spaces, ...)"
2298  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
2299  if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren))
2300  return failure();
2301 
2302  SmallVector<OpAsmParser::Argument> coords;
2303  if (failed(parseUsedCoordList(parser, state, coords)))
2304  return failure();
2305  size_t numCrds = coords.size();
2306 
2307  // Parse "iter_args(%arg = %init, ...)"
2308  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
2309  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2310  if (hasIterArgs)
2311  if (parser.parseAssignmentList(blockArgs, initArgs))
2312  return failure();
2313  blockArgs.append(coords);
2314 
2315  SmallVector<Type> iterSpaceTps;
2316  // parse ": (sparse_tensor.iter_space, ...) -> ret"
2317  if (parser.parseColon() || parser.parseLParen() ||
2318  parser.parseTypeList(iterSpaceTps) || parser.parseRParen())
2319  return failure();
2320 
2321  if (iterSpaceTps.size() != spaces.size())
2322  return parser.emitError(parser.getNameLoc(),
2323  "mismatch in number of iteration space operands "
2324  "and iteration space types");
2325 
2326  if (hasIterArgs)
2327  if (parser.parseArrowTypeList(state.types))
2328  return failure();
2329 
2330  // Resolves input sparse iteration spaces.
2331  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2332  spacesVals))
2333  return failure();
2334  state.operands.append(spacesVals);
2335 
2336  if (hasIterArgs) {
2337  // Strip off the trailing args that are used for coordinates.
2338  MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
2339  if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2340  return parser.emitError(
2341  parser.getNameLoc(),
2342  "mismatch in number of iteration arguments and return values");
2343  }
2344 
2345  for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2346  it.type = tp;
2347  if (parser.resolveOperand(init, tp, state.operands))
2348  return failure();
2349  }
2350  }
2351  return success();
2352 }
2353 
2354 LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2355  MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
2356  DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
2357  SmallVectorImpl<Type> &ret) {
2358 
2359  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2360  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
2361  ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
2362  adaptor.getHiLvl()));
2363  return success();
2364 }
2365 
2366 LogicalResult ExtractIterSpaceOp::verify() {
2367  if (getLoLvl() >= getHiLvl())
2368  return emitOpError("expected level lower bound to be smaller than upper bound");
2369 
2370  TypedValue<IteratorType> pIter = getParentIter();
2371  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2372  return emitOpError(
2373  "parent iterator should be specified iff level lower bound equals 0");
2374  }
2375 
2376  if (pIter) {
2377  IterSpaceType spaceTp = getExtractedSpace().getType();
2378  if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2379  return emitOpError(
2380  "mismatch in parent iterator encoding and iteration space encoding.");
2381 
2382  if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2383  return emitOpError("parent iterator should be used to extract an "
2384  "iteration space from a consecutive level.");
2385  }
2386 
2387  return success();
2388 }
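
// Example (sketch; the textual form follows the op's declared custom syntax
// and #CSR is a placeholder encoding): extracting the level-1 space of a
// CSR tensor requires the level-0 iterator as parent,
//
//   %s1 = sparse_tensor.extract_iteration_space %t at %it0 lvls = 1
//       : tensor<4x8xf64, #CSR>, !sparse_tensor.iterator<#CSR, lvls = 0>
//
// so that the consecutive-level check spaceTp.getLoLvl() ==
// parent.getHiLvl() holds (1 == 1).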
2389 
2390 LogicalResult ExtractValOp::verify() {
2391  auto stt = getSparseTensorType(getTensor());
2392  auto itTp = getIterator().getType();
2393 
2394  if (stt.getEncoding() != itTp.getEncoding())
2395  return emitOpError("mismatch in tensor encoding and iterator encoding.");
2396 
2397  if (stt.getLvlRank() != itTp.getHiLvl())
2398  return emitOpError("must use a last-level iterator to extract values");
2399 
2400  return success();
2401 }
2402 
2403 struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
2404  using OpRewritePattern::OpRewritePattern;
2405 
2406  LogicalResult matchAndRewrite(IterateOp iterateOp,
2407  PatternRewriter &rewriter) const override {
2408  I64BitSet newUsedLvls(0);
2409  llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
2410  for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
2411  if (auto crd = iterateOp.getLvlCrd(i)) {
2412  if (crd->getUsers().empty())
2413  toRemove.set(crd->getArgNumber());
2414  else
2415  newUsedLvls.set(i);
2416  }
2417  }
2418 
2419  // All coordinates are used.
2420  if (toRemove.none())
2421  return failure();
2422 
2423  rewriter.startOpModification(iterateOp);
2424  iterateOp.setCrdUsedLvls(newUsedLvls);
2425  iterateOp.getBody()->eraseArguments(toRemove);
2426  rewriter.finalizeOpModification(iterateOp);
2427  return success();
2428  }
2429 };
2430 
2431 void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
2432  mlir::MLIRContext *context) {
2433  results.add<RemoveUnusedLvlCrds>(context);
2434 }
2435 
2436 void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2437  Value iterSpace, ValueRange initArgs) {
2438  unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
2439  // All ones.
2440  I64BitSet set((1 << rank) - 1);
2441  return build(builder, odsState, iterSpace, initArgs, set);
2442 }
2443 
2444 void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2445  Value iterSpace, ValueRange initArgs,
2446  I64BitSet crdUsedLvls) {
2447  OpBuilder::InsertionGuard guard(builder);
2448 
2449  odsState.addOperands(iterSpace);
2450  odsState.addOperands(initArgs);
2451  odsState.getOrAddProperties<Properties>().crdUsedLvls =
2452  builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
2453  Region *bodyRegion = odsState.addRegion();
2454  odsState.addTypes(initArgs.getTypes());
2455  Block *bodyBlock = builder.createBlock(bodyRegion);
2456 
2457  // Starts with a list of user-provided loop arguments.
2458  for (Value v : initArgs)
2459  bodyBlock->addArgument(v.getType(), v.getLoc());
2460 
2461  // Followed by a list of used coordinates.
2462  for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
2463  bodyBlock->addArgument(builder.getIndexType(), odsState.location);
2464 
2465  // Ends with the sparse iterator.
2466  bodyBlock->addArgument(
2467  llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
2468  odsState.location);
2469 }
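
// The body block built above is therefore arranged as
//   (%init_0, ..., %init_{n-1}, %crd_0, ..., %crd_{m-1}, %iterator),
// with one index-typed coordinate per bit set in crdUsedLvls and the
// sparse iterator as the trailing argument.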
2470 
2471 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
2472  OpAsmParser::Argument iterator;
2473  OpAsmParser::UnresolvedOperand iterSpace;
2474 
2475  SmallVector<OpAsmParser::Argument> iters, iterArgs;
2476  if (parseSparseIterateLoop(parser, result, iters, iterArgs))
2477  return failure();
2478  if (iters.size() != 1)
2479  return parser.emitError(parser.getNameLoc(),
2480  "expected only one iterator/iteration space");
2481 
2482  iterArgs.append(iters);
2483  Region *body = result.addRegion();
2484  if (parser.parseRegion(*body, iterArgs))
2485  return failure();
2486 
2487  IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2488 
2489  // Parse the optional attribute list.
2490  if (parser.parseOptionalAttrDict(result.attributes))
2491  return failure();
2492 
2493  return success();
2494 }
2495 
2496 /// Prints the initialization list in the form of
2497 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
2498 /// where 'inner' values are assumed to be region arguments and 'outer' values
2499 /// are regular SSA values.
2500 static void printInitializationList(OpAsmPrinter &p,
2501  Block::BlockArgListType blocksArgs,
2502  ValueRange initializers,
2503  StringRef prefix = "") {
2504  assert(blocksArgs.size() == initializers.size() &&
2505  "expected same length of arguments and initializers");
2506  if (initializers.empty())
2507  return;
2508 
2509  p << prefix << '(';
2510  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
2511  p << std::get<0>(it) << " = " << std::get<1>(it);
2512  });
2513  p << ")";
2514 }
2515 
2516 template <typename SparseLoopOp>
2517 static LogicalResult verifySparseLoopOp(SparseLoopOp op) {
2518  if (op.getInitArgs().size() != op.getNumResults()) {
2519  return op.emitOpError(
2520  "mismatch in number of loop-carried values and defined values");
2521  }
2522  if (op.getCrdUsedLvls().max() > op.getSpaceDim())
2523  return op.emitOpError("requested out-of-bound coordinates");
2524 
2525  return success();
2526 }
2527 
2528 LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); }
2529 LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(*this); }
2530 
2531 void IterateOp::print(OpAsmPrinter &p) {
2532  p << " " << getIterator() << " in " << getIterSpace();
2533  if (!getCrdUsedLvls().empty()) {
2534  p << " at(";
2535  printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
2536  p << ")";
2537  }
2538  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
2539 
2540  p << " : " << getIterSpace().getType() << " ";
2541  if (!getInitArgs().empty())
2542  p.printArrowTypeList(getInitArgs().getTypes());
2543 
2544  p << " ";
2545  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2546  /*printBlockTerminators=*/!getInitArgs().empty());
2547 }
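
// The printed form produced above reads, e.g. (sketch; #CSR and the SSA
// names are placeholders):
//
//   sparse_tensor.iterate %it in %space at(%crd) iter_args(%a = %init)
//       : !sparse_tensor.iter_space<#CSR, lvls = 0> -> index { ... }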
2548 
2549 LogicalResult IterateOp::verifyRegions() {
2550  if (getIterator().getType() != getIterSpace().getType().getIteratorType())
2551  return emitOpError("mismatch in iterator and iteration space type");
2552  if (getNumRegionIterArgs() != getNumResults())
2553  return emitOpError(
2554  "mismatch in number of basic block args and defined values");
2555 
2556  auto initArgs = getInitArgs();
2557  auto iterArgs = getRegionIterArgs();
2558  auto yieldVals = getYieldedValues();
2559  auto opResults = getResults();
2560  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2561  opResults.size()})) {
2562  return emitOpError() << "number mismatch between iter args and results.";
2563  }
2564 
2565  for (auto [i, init, iter, yield, ret] :
2566  llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2567  if (init.getType() != ret.getType())
2568  return emitOpError() << "types mismatch between " << i
2569  << "th iter operand and defined value";
2570  if (iter.getType() != ret.getType())
2571  return emitOpError() << "types mismatch between " << i
2572  << "th iter region arg and defined value";
2573  if (yield.getType() != ret.getType())
2574  return emitOpError() << "types mismatch between " << i
2575  << "th yield value and defined value";
2576  }
2577 
2578  return success();
2579 }
2580 
2581 /// OpInterfaces' methods implemented by IterateOp.
2582 SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
2583 
2584 MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
2585  return getInitArgsMutable();
2586 }
2587 
2588 Block::BlockArgListType IterateOp::getRegionIterArgs() {
2589  return getRegion().getArguments().take_front(getNumRegionIterArgs());
2590 }
2591 
2592 std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
2593  return cast<sparse_tensor::YieldOp>(
2594  getRegion().getBlocks().front().getTerminator())
2595  .getResultsMutable();
2596 }
2597 
2598 std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
2599 
2600 OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2601  return getInitArgs();
2602 }
2603 
2604 void IterateOp::getSuccessorRegions(RegionBranchPoint point,
2605  SmallVectorImpl<RegionSuccessor> &regions) {
2606  // Both the operation itself and the region may be branching into the body
2607  // or back into the operation itself.
2608  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2609  // It is possible for the loop not to enter the body.
2610  regions.push_back(RegionSuccessor(getResults()));
2611 }
2612 
2613 void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
2614  ValueRange iterSpaces, ValueRange initArgs,
2615  unsigned numCases) {
2616  unsigned rank =
2617  cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim();
2618  // All ones.
2619  I64BitSet set((1 << rank) - 1);
2620  // Generates all-zero case bits (they only serve as placeholders), which are
2621  // supposed to be overridden later. We need to preallocate all the regions,
2622  // as an mlir::Region cannot be added dynamically after the operation is
2623  // created.
2624  SmallVector<int64_t> caseBits(numCases, 0);
2625  ArrayAttr cases = builder.getI64ArrayAttr(caseBits);
2626  return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
2627  initArgs, set, cases,
2628  /*caseRegionsCount=*/numCases);
2629 }
2630 
2631 ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {
2632 
2633  SmallVector<Value> spaces;
2634  // The block argument list of each region is arranged in the order of
2635  // ([used coordinate list], [loop iteration args], [sparse iterator list]).
2636  SmallVector<OpAsmParser::Argument> blockArgs;
2637  if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs))
2638  return failure();
2639 
2640  result.addAttribute("operandSegmentSizes",
2641  parser.getBuilder().getDenseI32ArrayAttr(
2642  {static_cast<int32_t>(spaces.size()),
2643  static_cast<int32_t>(result.types.size())}));
2644 
2645  SmallVector<Attribute> cases;
2646  while (succeeded(parser.parseOptionalKeyword("case"))) {
2647  // Parse one region per case.
2648  I64BitSet definedItSet;
2649  SmallVector<OpAsmParser::Argument> definedIts;
2650  if (parseOptionalDefinedList(parser, result, definedItSet, definedIts,
2651  spaces.size(), OpAsmParser::Delimiter::None))
2652  return failure();
2653 
2654  cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet));
2655 
2656  for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) {
2657  // Resolve the iterator type based on the iteration space type.
2658  auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType());
2659  definedIts[i].type = spaceTp.getIteratorType();
2660  }
2661  definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
2662  Region *body = result.addRegion();
2663  if (parser.parseRegion(*body, definedIts))
2664  return failure();
2665 
2666  CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2667  }
2668 
2669  result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases));
2670 
2671  // Parse the optional attribute list.
2672  if (parser.parseOptionalAttrDict(result.attributes))
2673  return failure();
2674 
2675  return success();
2676 }
2677 
2678 void CoIterateOp::print(OpAsmPrinter &p) {
2679  p << " (";
2680  llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; });
2681  p << ")";
2682 
2683  if (!getCrdUsedLvls().empty()) {
2684  p << " at(";
2685  printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls());
2686  p << ")";
2687  }
2688 
2689  printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args");
2690 
2691  p << " : (" << getIterSpaces().getTypes() << ")";
2692  if (!getInitArgs().empty())
2693  p.printArrowTypeList(getInitArgs().getTypes());
2694 
2695  for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
2696  p.printNewline();
2697  p << "case ";
2698  printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx),
2699  getRegionDefinedSpace(idx));
2700  p << " ";
2701  p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false,
2702  /*printBlockTerminators=*/!getInitArgs().empty());
2703  }
2704 }
2705 
2706 ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
2707  return cast<sparse_tensor::YieldOp>(
2708  getRegion(regionIdx).getBlocks().front().getTerminator())
2709  .getResults();
2710 }
2711 
2712 LogicalResult CoIterateOp::verifyRegions() {
2713  for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
2714  if (getNumRegionIterArgs() != getNumResults())
2715  return emitOpError(
2716  "mismatch in number of basic block args and defined values");
2717 
2718  auto initArgs = getInitArgs();
2719  auto iterArgs = getRegionIterArgs(r);
2720  auto yieldVals = getYieldedValues(r);
2721  auto opResults = getResults();
2722  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2723  opResults.size()})) {
2724  return emitOpError()
2725  << "number mismatch between iter args and results on " << r
2726  << "th region";
2727  }
2728 
2729  for (auto [i, init, iter, yield, ret] :
2730  llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2731  if (init.getType() != ret.getType())
2732  return emitOpError()
2733  << "types mismatch between " << i
2734  << "th iter operand and defined value on " << r << "th region";
2735  if (iter.getType() != ret.getType())
2736  return emitOpError() << "types mismatch between " << i
2737  << "th iter region arg and defined value on " << r
2738  << "th region";
2739  if (yield.getType() != ret.getType())
2740  return emitOpError()
2741  << "types mismatch between " << i
2742  << "th yield value and defined value on " << r << "th region";
2743  }
2744  }
2745 
2746  auto cases = getRegionDefinedSpaces();
2747  llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
2748  if (set.size() != getNumRegions())
2749  return emitOpError("contains duplicated cases.");
2750 
2751  return success();
2752 }
2753 
2754 SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
2755  SmallVector<Region *> ret;
2756  I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
2757  for (Region &r : getCaseRegions())
2758  if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
2759  ret.push_back(&r);
2760 
2761  return ret;
2762 }
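
// Example (sketch): if region regionIdx is defined over spaces {0, 1}
// (caseBit = 0b11), its sub-cases are all regions whose defined space is a
// subset of 0b11, e.g. the cases over 0b01 (space 0 only) and 0b10
// (space 1 only), including the case itself.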
2763 
2764 //===----------------------------------------------------------------------===//
2765 // Sparse Tensor Dialect Setups.
2766 //===----------------------------------------------------------------------===//
2767 
2768 /// Materialize a single constant operation from a given attribute value with
2769 /// the desired resultant type.
2770 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2771  Attribute value, Type type,
2772  Location loc) {
2773  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2774  return op;
2775  return nullptr;
2776 }
2777 
2778 void SparseTensorDialect::initialize() {
2779  addAttributes<
2780 #define GET_ATTRDEF_LIST
2781 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
2782  >();
2783  addTypes<
2784 #define GET_TYPEDEF_LIST
2785 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
2786  >();
2787  addOperations<
2788 #define GET_OP_LIST
2789 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2790  >();
2791  declarePromisedInterfaces<
2792  bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
2793  NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
2794  ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
2795 }
2796 
2797 #define GET_OP_CLASSES
2798 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2799 
2800 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
Definition: AMXDialect.cpp:85
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static bool isPermutation(const std::vector< PermutationTy > &permutation)
Definition: IRAffine.cpp:67
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
union mlir::linalg::@1242::ArityGroupAndKind::Kind kind
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
bool isUnique(It begin, It end)
Definition: ShardOps.cpp:161
static LogicalResult verifyNumBlockArgs(T *op, Region &region, const char *regionName, TypeRange inputTypes, Type outputType)
static ParseResult parseOptionalStaticSlice(int64_t &result, AsmParser &parser)
static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc)
We normalized sparse tensor encoding attribute by always using ordered/unique LT such that "compresse...
static ParseResult parseUsedCoordList(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &coords)
static LogicalResult isMatchingWidth(Value mem, unsigned width)
static constexpr bool acceptBitWidth(unsigned bitWidth)
static mlir::ParseResult parseLevelRange(mlir::AsmParser &, mlir::sparse_tensor::Level &, mlir::sparse_tensor::Level &)
Parses a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static LogicalResult lvlIsInBounds(Level lvl, Value tensor)
static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size, Block::BlockArgListType blocksArgs, I64BitSet definedSet)
static constexpr FieldIndex kDataFieldStartingIdx
static constexpr Level kInvalidLevel
static LogicalResult verifySparseLoopOp(SparseLoopOp op)
static constexpr Level kInvalidFieldIndex
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, mlir::sparse_tensor::Level)
Prints a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind)
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op)
static SmallVector< Size > getSparseFieldShape(const SparseTensorEncodingAttr enc, std::optional< ArrayRef< int64_t >> dimShape)
static ParseResult parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &iterators, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static ParseResult parseOptionalDefinedList(OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, SmallVectorImpl< OpAsmParser::Argument > &definedArgs, unsigned maxCnt=std::numeric_limits< unsigned >::max(), OpAsmParser::Delimiter delimiter=OpAsmParser::Delimiter::Paren)
Parses a list of optional defined list in the form of "(%val0, _, %val1, ...)", where _ is used to an...
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, SparseTensorType stt, RankedTensorType valTp, TypeRange lvlTps)
static ParseResult parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< Value > &spacesVals, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static LogicalResult verifySparsifierGetterSetter(StorageSpecifierKind mdKind, std::optional< Level > lvl, TypedValue< StorageSpecifierType > md, Operation *op)
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, OpaqueProperties prop, RegionRange region, SmallVectorImpl< mlir::Type > &ret)
static bool isAllDense(uint64_t lvlRank, const LevelType *lvlTypes)
Definition: Storage.cpp:20
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
Base type for affine expression.
Definition: AffineExpr.h:68
void print(raw_ostream &os) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
MLIRContext * getContext() const
Definition: AffineMap.cpp:339
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:411
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:330
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
Definition: AffineMap.cpp:365
unsigned getNumSymbols() const
Definition: AffineMap.cpp:394
unsigned getNumDims() const
Definition: AffineMap.cpp:390
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
unsigned getNumResults() const
Definition: AffineMap.cpp:398
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:407
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:641
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseLBrace()=0
Parse a { token.
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseQuestion()=0
Parse a '?' token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
Definition: AsmPrinter.cpp:77
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printArrowTypeList(TypeRange &&types)
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:153
BlockArgListType getArguments()
Definition: Block.h:87
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:158
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:107
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:276
IndexType getIndexType()
Definition: Builders.cpp:50
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.
Definition: PatternMatch.h:783
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:845
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen.
Definition: PatternMatch.h:612
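A sketch tying these pieces together: an OpRewritePattern that performs an in-place update bracketed by startOpModification/finalizeOpModification. The "visited" attribute is hypothetical; IterateOp is the sparse_tensor iteration op referenced further below:

struct MarkVisited : public OpRewritePattern<IterateOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(IterateOp op,
                                PatternRewriter &rewriter) const override {
    if (op->hasAttr("visited"))
      return failure(); // Already processed; avoid endless re-application.
    rewriter.startOpModification(op);
    op->setAttr("visited", rewriter.getUnitAttr());
    rewriter.finalizeOpModification(op);
    return success();
  }
};
// Registered like any other pattern: patterns.add<MarkVisited>(ctx);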
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
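A small sketch of the Value accessors above; v is any SSA value:

Type t = v.getType();       // Every Value carries a type.
Location loc = v.getLoc();  // And a location usable for diagnostics.
(void)t; (void)loc;
if (Operation *def = v.getDefiningOp()) {
  // Non-null only for op results; block arguments have no defining op.
  for (Value operand : def->getOperands())
    (void)operand; // E.g., inspect operand.getType() here.
}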
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
A simple wrapper to encode a bitset of (at most 64) levels, currently used by the sparse_tensor.iterate operation.
Definition: SparseTensor.h:64
iterator_range< const_set_bits_iterator > bits() const
Definition: SparseTensor.h:75
I64BitSet & set(unsigned i)
Definition: SparseTensor.h:83
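A sketch of I64BitSet usage (the level numbers are illustrative); set returns the bitset itself, so calls can be chained:

I64BitSet usedLvls;
usedLvls.set(0).set(2); // Mark levels 0 and 2 as used.
for (unsigned l : usedLvls.bits())
  llvm::outs() << "used level: " << l << "\n";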
A wrapper around RankedTensorType, which has three goals:
unsigned getCrdWidth() const
Returns the coordinate-overhead bitwidth, defaulting to zero.
SmallVector< Size > getBatchLvlShape() const
Returns the batched level-shape.
ArrayRef< LevelType > getLvlTypes() const
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
ArrayRef< Size > getDimShape() const
Returns the dimension-shape.
bool isAllOrdered() const
Returns true for tensors where every level is ordered.
SmallVector< Size > getLvlShape() const
Returns the level-shape.
bool isCOOType(Level startLvl=0, bool isUnique=true) const
Returns true iff this sparse tensor type has a trailing COO region starting at the given level.
Dimension getDimRank() const
Returns the dimension-rank.
bool isAllDense() const
Returns true for tensors where every level is dense.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
bool hasSameDimToLvl(const SparseTensorType &other) const
Returns true iff the two types have the same mapping.
bool hasStaticDimShape() const
Returns true if no dimension has dynamic size.
Level getLvlRank() const
Returns the level-rank.
unsigned getPosWidth() const
Returns the position-overhead bitwidth, defaulting to zero.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least two levels.
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
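A sketch exercising several of the SparseTensorType queries above; val is any Value, and tryGetSparseTensorType (documented further below) returns std::nullopt for values that are not ranked tensors:

if (std::optional<SparseTensorType> stt = tryGetSparseTensorType(val)) {
  if (stt->hasEncoding()) {
    Level lvlRank = stt->getLvlRank();     // Number of storage levels.
    Dimension dimRank = stt->getDimRank(); // Number of tensor dimensions.
    Type posTy = stt->getPosType();        // Defaults to IndexType.
    Type crdTy = stt->getCrdType();        // Defaults to IndexType.
    (void)lvlRank; (void)dimRank; (void)posTy; (void)crdTy;
    if (stt->isAllDense() && stt->isIdentity()) {
      // All-dense storage with an identity dimToLvl mapping.
    }
  }
}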
Provides methods to access fields of a sparse tensor with the given encoding.
unsigned getNumDataFields() const
Gets the total number of data fields (coordinate arrays, position arrays, and a value array) for the given sparse tensor encoding.
unsigned getNumFields() const
Gets the total number of fields for the given sparse tensor encoding.
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with the corresponding field index, field kind, level, and level type.
std::pair< FieldIndex, unsigned > getFieldIndexAndStride(SparseTensorFieldKind kind, std::optional< Level > lvl) const
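A sketch of StorageLayout::foreachField, counting position memrefs for an encoding enc; the field-kind name follows the SparseTensorFieldKind enum documented below:

StorageLayout layout(enc);
unsigned numPosFields = 0;
layout.foreachField([&](FieldIndex fid, SparseTensorFieldKind kind,
                        Level lvl, LevelType lt) -> bool {
  if (kind == SparseTensorFieldKind::PosMemRef)
    ++numPosFields;
  return true; // Keep iterating; returning false stops early.
});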
DimLvlMapParser
Parses the Sparse Tensor Encoding Attribute (STEA).
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition: Barvinok.cpp:63
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
bool isWithCrdLT(LevelType lt)
Definition: Enums.h:431
bool isWithPosLT(LevelType lt)
Definition: Enums.h:432
bool isOrderedLT(LevelType lt)
Definition: Enums.h:425
std::string toMLIRString(LevelType lt)
Definition: Enums.h:447
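A sketch of the LevelType predicate helpers above; lt is any level type:

if (isWithPosLT(lt)) {
  // The level needs a positions array (e.g. a compressed level).
}
if (isWithCrdLT(lt) && isOrderedLT(lt)) {
  // The level stores coordinates and keeps them ordered.
}
llvm::outs() << toMLIRString(lt) << "\n"; // Printable MLIR syntax.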
Dimension toDim(SparseTensorEncodingAttr enc, Level l)
Convenience method to translate the given level to the corresponding dimension.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
bool isSingletonLT(LevelType lt)
Definition: Enums.h:421
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
Definition: SparseTensor.h:39
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:42
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
uint64_t getN(LevelType lt)
Definition: Enums.h:442
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic (for shapes).
Definition: SparseTensor.h:46
llvm::hash_code hash_value(LevelType lt)
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context)
Given the dimToLvl map, infers the lvlToDim map, or returns empty Affine map when inference fails.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Definition: SparseTensor.h:168
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
bool isBlockSparsity(AffineMap dimToLvl)
Given the dimToLvl map, returns if it's block sparsity.
bool isDenseLT(LevelType lt)
Definition: Enums.h:413
uint64_t getM(LevelType lt)
Definition: Enums.h:443
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff the MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SparseTensorFieldKind
The sparse tensor storage scheme for a tensor is organized as a single compound type with the following fields ...
bool isBatchLT(LevelType lt)
Definition: Enums.h:414
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context)
Returns the lvlToDim map for the given dimToLvl map specific to the block sparse cases.
std::optional< LevelType > buildLevelType(LevelFormat lf, const std::vector< LevelPropNonDefault > &properties, uint64_t n=0, uint64_t m=0)
Definition: Enums.h:402
bool isNOutOfMLT(LevelType lt)
Definition: Enums.h:424
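A sketch constructing a 2:4 structured level type with buildLevelType and querying it with the n:m helpers above; buildLevelType returns std::nullopt for invalid combinations:

std::optional<LevelType> lt =
    buildLevelType(LevelFormat::NOutOfM, /*properties=*/{}, /*n=*/2, /*m=*/4);
if (lt && isNOutOfMLT(*lt)) {
  assert(getN(*lt) == 2 && getM(*lt) == 4);
}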
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:488
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
Definition: AffineMap.cpp:784
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:68
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:643
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:619
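A sketch using the AffineExpr helpers above to build a 2x3 block-sparse dimToLvl map, (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3), checked against the block-sparsity helpers documented earlier; AffineMap::get is the standard MLIR constructor:

static AffineMap blockDimToLvl(MLIRContext *ctx) {
  AffineExpr d0 = getAffineDimExpr(0, ctx);
  AffineExpr d1 = getAffineDimExpr(1, ctx);
  AffineExpr c2 = getAffineConstantExpr(2, ctx);
  AffineExpr c3 = getAffineConstantExpr(3, ctx);
  // floorDiv/% are sugar over getAffineBinaryOpExpr(FloorDiv/Mod, ...).
  AffineMap dimToLvl = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/0,
      {d0.floorDiv(c2), d1.floorDiv(c3), d0 % c2, d1 % c3}, ctx);
  assert(isBlockSparsity(dimToLvl));
  // inferLvlToDim(dimToLvl, ctx) recovers the corresponding lvlToDim map.
  return dimToLvl;
}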
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
LogicalResult matchAndRewrite(IterateOp iterateOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
Definition: PatternMatch.h:319
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) the properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.
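A sketch assembling an operation generically through OperationState; the op name and attribute are hypothetical, and builder, loc, and operands are assumed to exist (creating an op with an unregistered name also requires the context to allow unregistered dialects):

OperationState state(loc, "sparse_tensor.example");
state.addOperands(operands);                        // Forward the inputs.
Type idxTy = builder.getIndexType();
state.addTypes(idxTy);                              // One index-typed result.
state.addAttribute("tag", builder.getI64IntegerAttr(7));
Region *body = state.addRegion();                   // An empty region to fill.
(void)body;
Operation *op = builder.create(state);              // Materialize the op.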
A simple structure that encodes a range of levels in a sparse tensor that forms a COO segment.
Definition: SparseTensor.h:50
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.
Definition: Enums.h:326
LevelType stripStorageIrrelevantProperties() const
Definition: Enums.h:299