//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

8 
9 #include <utility>
10 
11 #include "Detail/DimLvlMapParser.h"
12 
17 
22 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "llvm/ADT/Bitset.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 #include "llvm/Support/FormatVariadic.h"
30 
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"

// Forward declarations: the following custom print/parse methods are
// referenced by the generated code for SparseTensorTypes.td.
static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
                                         mlir::sparse_tensor::Level &,
                                         mlir::sparse_tensor::Level &);
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
                            mlir::sparse_tensor::Level);

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed
// as well.
namespace mlir::sparse_tensor {
llvm::hash_code hash_value(LevelType lt) {
  return llvm::hash_value(static_cast<uint64_t>(lt));
}
} // namespace mlir::sparse_tensor

//===----------------------------------------------------------------------===//
// Local Convenience Methods.
//===----------------------------------------------------------------------===//

static constexpr bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}

static SmallVector<Size>
getSparseFieldShape(const SparseTensorEncodingAttr enc,
                    std::optional<ArrayRef<int64_t>> dimShape) {
  assert(enc);
  // With only the encoding, we cannot determine the static shape of leading
  // batch levels; we therefore return a dynamically shaped memref instead.
  SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
  if (dimShape.has_value()) {
    // If the actual tensor shape is provided, we can then refine the leading
    // batch dimensions.
    SmallVector<int64_t> lvlShape =
        enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
    memrefShape.assign(lvlShape.begin(),
                       lvlShape.begin() + enc.getBatchLvlRank());
  }
  // Another dynamic dimension to store the sparse level.
  memrefShape.push_back(ShapedType::kDynamic);
  return memrefShape;
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect StorageLayout.
//===----------------------------------------------------------------------===//

static constexpr Level kInvalidLevel = -1u;
static constexpr Level kInvalidFieldIndex = -1u;
static constexpr FieldIndex kDataFieldStartingIdx = 0;

void StorageLayout::foreachField(
    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) const {
  const auto lvlTypes = enc.getLvlTypes();
  const Level lvlRank = enc.getLvlRank();
  SmallVector<COOSegment> cooSegs = enc.getCOOSegments();
  FieldIndex fieldIdx = kDataFieldStartingIdx;

  ArrayRef cooSegsRef = cooSegs;
  // Per-level storage.
  for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
    const auto lt = lvlTypes[l];
    if (isWithPosLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
        return;
    }
    if (isWithCrdLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
        return;
    }
    if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
      if (!cooSegsRef.front().isSoA) {
        // AoS COO, all singletons are fused into one memref. Skip the entire
        // COO segment.
        l = cooSegsRef.front().lvlRange.second;
      } else {
        // SoA COO, each singleton level has its own memref.
        l++;
      }
      // Expire the handled COO segment.
      cooSegsRef = cooSegsRef.drop_front();
    } else {
      // Non-COO levels.
      l++;
    }
  }
  // The values array.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
  // Put metadata at the end.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
}

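// Illustrative only (not part of the upstream implementation): for a 2-D CSR
// encoding, the traversal above visits four fields in order, which a client
// can observe with a simple counting callback. A minimal sketch, where `enc`
// is assumed to be any SparseTensorEncodingAttr with a CSR layout:
//
//   unsigned numVisited = 0;
//   StorageLayout(enc).foreachField(
//       [&numVisited](FieldIndex, SparseTensorFieldKind, Level, LevelType) {
//         numVisited++;  // one invocation per storage field
//         return true;   // keep iterating
//       });
//   // CSR yields: positions[1], coordinates[1], values, storage specifier,
//   // hence numVisited == 4.
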
void sparse_tensor::foreachFieldAndTypeInSparseTensor(
    SparseTensorType stt,
    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) {
  assert(stt.hasEncoding());

  SmallVector<int64_t> memrefShape =
      getSparseFieldShape(stt.getEncoding(), stt.getDimShape());

  const Type specType = StorageSpecifierType::get(stt.getEncoding());
  // memref<[batch] x ? x pos> positions
  const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
  // memref<[batch] x ? x crd> coordinates
  const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
  // memref<[batch] x ? x eltType> values
  const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());

  StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
                                   callback](FieldIndex fieldIdx,
                                             SparseTensorFieldKind fieldKind,
                                             Level lvl, LevelType lt) -> bool {
    switch (fieldKind) {
    case SparseTensorFieldKind::StorageSpec:
      return callback(specType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::PosMemRef:
      return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::CrdMemRef:
      return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::ValMemRef:
      return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
    };
    llvm_unreachable("unrecognized field kind");
  });
}

unsigned StorageLayout::getNumFields() const {
  unsigned numFields = 0;
  foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
    numFields++;
    return true;
  });
  return numFields;
}

unsigned StorageLayout::getNumDataFields() const {
  unsigned numFields = 0; // one value memref
  foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
    if (fidx >= kDataFieldStartingIdx)
      numFields++;
    return true;
  });
  numFields -= 1; // the last field is the StorageSpecifier
  assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
  return numFields;
}

std::pair<FieldIndex, unsigned>
StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
                                      std::optional<Level> lvl) const {
  FieldIndex fieldIdx = kInvalidFieldIndex;
  unsigned stride = 1;
  if (kind == SparseTensorFieldKind::CrdMemRef) {
    assert(lvl.has_value());
    const Level cooStart = enc.getAoSCOOStart();
    const Level lvlRank = enc.getLvlRank();
    if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
      lvl = cooStart;
      stride = lvlRank - cooStart;
    }
  }
  foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
                                      SparseTensorFieldKind fKind, Level fLvl,
                                      LevelType lt) -> bool {
    if ((lvl && fLvl == lvl.value() && kind == fKind) ||
        (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
      fieldIdx = fIdx;
      // Returns false to break the iteration.
      return false;
    }
    return true;
  });
  assert(fieldIdx != kInvalidFieldIndex);
  return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//

std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
  return isDynamic(v) ? std::nullopt
                      : std::make_optional(static_cast<uint64_t>(v));
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
  return getStatic(getOffset());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
  return getStatic(getStride());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
  return getStatic(getSize());
}

bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
  return isDynamic(getOffset()) && isDynamic(getStride()) &&
         isDynamic(getSize());
}

std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
  return isDynamic(v) ? "?" : std::to_string(v);
}

void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
  assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
  os << '(';
  os << getStaticString(getOffset());
  os << ", ";
  os << getStaticString(getSize());
  os << ", ";
  os << getStaticString(getStride());
  os << ')';
}

void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
  print(printer.getStream());
}

static ParseResult parseOptionalStaticSlice(int64_t &result,
                                            AsmParser &parser) {
  auto parseResult = parser.parseOptionalInteger(result);
  if (parseResult.has_value()) {
    if (parseResult.value().succeeded() && result < 0) {
      parser.emitError(
          parser.getCurrentLocation(),
          "expect positive value or ? for slice offset/size/stride");
      return failure();
    }
    return parseResult.value();
  }

  // Else, a '?' represents a dynamic slice value.
  result = SparseTensorDimSliceAttr::kDynamic;
  return parser.parseQuestion();
}

Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
  int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;

  if (failed(parser.parseLParen()) ||
      failed(parseOptionalStaticSlice(offset, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(size, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(stride, parser)) ||
      failed(parser.parseRParen()))
    return {};

  return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
                                                     offset, size, stride);
}

LogicalResult
SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 int64_t offset, int64_t size, int64_t stride) {
  if (!isDynamic(offset) && offset < 0)
    return emitError() << "expect non-negative value or ? for slice offset";
  if (!isDynamic(size) && size <= 0)
    return emitError() << "expect positive value or ? for slice size";
  if (!isDynamic(stride) && stride <= 0)
    return emitError() << "expect positive value or ? for slice stride";
  return success();
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
  return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
  return withDimToLvl(AffineMap());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                        unsigned crdWidth) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
      crdWidth, getExplicitVal(), getImplicitVal());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
  return withBitWidths(0, 0);
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), explicitVal, getImplicitVal());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
  return withExplicitVal(Attribute());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), implicitVal);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
  return withImplicitVal(Attribute());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
  return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
}

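// Illustrative only: the "with"/"without" helpers above compose, which makes
// it easy to derive encodings that differ in a single component. A minimal
// sketch, where `enc` is assumed to be any initialized
// SparseTensorEncodingAttr:
//
//   auto wide = enc.withBitWidths(/*posWidth=*/64, /*crdWidth=*/64);
//   // Same level types and maps; only the position/coordinate widths differ.
//   assert(wide.withoutBitWidths() == enc.withoutBitWidths());
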
uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
  ArrayRef<LevelType> lvlTypes = getLvlTypes();
  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  return std::distance(lastBatch, lvlTypes.rend());
}

bool SparseTensorEncodingAttr::isAllDense() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
}

bool SparseTensorEncodingAttr::isAllOrdered() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
}

Type SparseTensorEncodingAttr::getCrdElemType() const {
  if (!getImpl())
    return nullptr;
  if (getCrdWidth())
    return IntegerType::get(getContext(), getCrdWidth());
  return IndexType::get(getContext());
}

Type SparseTensorEncodingAttr::getPosElemType() const {
  if (!getImpl())
    return nullptr;
  if (getPosWidth())
    return IntegerType::get(getContext(), getPosWidth());
  return IndexType::get(getContext());
}

MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
  return MemRefType::get(shape, getCrdElemType());
}

MemRefType SparseTensorEncodingAttr::getPosMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
  return MemRefType::get(shape, getPosElemType());
}

bool SparseTensorEncodingAttr::isIdentity() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
}

bool SparseTensorEncodingAttr::isPermutation() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
}

Dimension SparseTensorEncodingAttr::getDimRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  const auto dimToLvl = getDimToLvl();
  return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
}

Level SparseTensorEncodingAttr::getLvlRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return getLvlTypes().size();
}

LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
  if (!getImpl())
    return LevelFormat::Batch;
  assert(l < getLvlRank() && "Level is out of bounds");
  return getLvlTypes()[l];
}

bool SparseTensorEncodingAttr::isSlice() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return !getDimSlices().empty();
}

SparseTensorDimSliceAttr
SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
  assert(isSlice() && "Is not a slice");
  const auto dimSlices = getDimSlices();
  assert(dim < dimSlices.size() && "Dimension is out of bounds");
  return dimSlices[dim];
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
  return getDimSlice(dim).getStaticOffset();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
  return getDimSlice(dim).getStaticStride();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
  return getStaticDimSliceOffset(toDim(*this, lvl));
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
  return getStaticDimSliceStride(toDim(*this, lvl));
}

SmallVector<int64_t>
SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
                                         CrdTransDirectionKind dir) const {
  if (isIdentity())
    return SmallVector<int64_t>(srcShape);

  SmallVector<int64_t> ret;
  unsigned rank =
      dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
  ret.reserve(rank);

  if (isPermutation()) {
    for (unsigned r = 0; r < rank; r++) {
      unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
                                                             : toLvl(*this, r);
      ret.push_back(srcShape[trans]);
    }
    return ret;
  }

  // Handle non-permutation maps.
  AffineMap transMap =
      dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();

  SmallVector<AffineExpr> dimRep;
  dimRep.reserve(srcShape.size());
  for (int64_t sz : srcShape) {
    if (!ShapedType::isDynamic(sz)) {
      // Push back the max coordinate for the given dimension/level size.
      dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
    } else {
      // A dynamic size; use an AffineDimExpr to symbolize the value.
      dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
    }
  }

  for (AffineExpr exp : transMap.getResults()) {
    // Do constant propagation on the affine map.
    AffineExpr evalExp =
        simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
    // Use the llvm namespace here to avoid ambiguity.
    if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
      ret.push_back(c.getValue() + 1);
    } else {
      if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
          mod && mod.getKind() == AffineExprKind::Mod) {
        // We can still infer a static bound for expressions of the form
        // "d % constant", since d % constant is in [0, constant).
        if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
          ret.push_back(bound.getValue());
          continue;
        }
      }
      ret.push_back(ShapedType::kDynamic);
    }
  }
  assert(ret.size() == rank);
  return ret;
}

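// Illustrative only: for a 2x2 block-sparse (BSR-style) map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2),
// translating the dimension shape [8, 8] in the dim2lvl direction propagates
// the constant bounds through the map and yields the level shape [4, 4, 2, 2];
// the "mod" results receive the static bound 2 even though their simplified
// expressions are not constants.
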
ValueRange
SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
                                        ValueRange crds,
                                        CrdTransDirectionKind dir) const {
  if (!getImpl())
    return crds;

  SmallVector<Type> retType(
      dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
      builder.getIndexType());
  auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
  return transOp.getOutCrds();
}

Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  // Open "<{" part.
  if (failed(parser.parseLess()))
    return {};
  if (failed(parser.parseLBrace()))
    return {};

  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<LevelType> lvlTypes;
  SmallVector<SparseTensorDimSliceAttr> dimSlices;
  AffineMap dimToLvl = {};
  AffineMap lvlToDim = {};
  unsigned posWidth = 0;
  unsigned crdWidth = 0;
  Attribute explicitVal;
  Attribute implicitVal;
  StringRef attrName;
  SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
                                    "explicitVal", "implicitVal"};
  while (succeeded(parser.parseOptionalKeyword(&attrName))) {
    // Detect admissible keyword.
    auto *it = find(keys, attrName);
    if (it == keys.end()) {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
      return {};
    }
    unsigned keyWordIndex = it - keys.begin();
    // Consume the `=` after the key.
    if (failed(parser.parseEqual()))
      return {};
    // Dispatch on the keyword.
    switch (keyWordIndex) {
    case 0: { // map
      ir_detail::DimLvlMapParser cParser(parser);
      auto res = cParser.parseDimLvlMap();
      if (failed(res))
        return {};
      const auto &dlm = *res;

      const Level lvlRank = dlm.getLvlRank();
      for (Level lvl = 0; lvl < lvlRank; lvl++)
        lvlTypes.push_back(dlm.getLvlType(lvl));

      const Dimension dimRank = dlm.getDimRank();
      for (Dimension dim = 0; dim < dimRank; dim++)
        dimSlices.push_back(dlm.getDimSlice(dim));
      // NOTE: the old syntax requires an all-or-nothing approach to
      // `dimSlices`; therefore, if any slice actually exists then we need
      // to convert null-DSA into default/nop DSA.
      const auto isDefined = [](SparseTensorDimSliceAttr slice) {
        return static_cast<bool>(slice.getImpl());
      };
      if (llvm::any_of(dimSlices, isDefined)) {
        const auto defaultSlice =
            SparseTensorDimSliceAttr::get(parser.getContext(), 0, 0, 1);
        for (Dimension dim = 0; dim < dimRank; dim++)
          if (!isDefined(dimSlices[dim]))
            dimSlices[dim] = defaultSlice;
      } else {
        dimSlices.clear();
      }

      dimToLvl = dlm.getDimToLvlMap(parser.getContext());
      lvlToDim = dlm.getLvlToDimMap(parser.getContext());
      break;
    }
    case 1: { // posWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral position bitwidth");
        return {};
      }
      posWidth = intAttr.getInt();
      break;
    }
    case 2: { // crdWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      crdWidth = intAttr.getInt();
      break;
    }
    case 3: { // explicitVal
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
        explicitVal = result;
      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
        explicitVal = result;
      } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
        explicitVal = result;
      } else {
        parser.emitError(parser.getNameLoc(),
                         "expected a numeric value for explicitVal");
        return {};
      }
      break;
    }
    case 4: { // implicitVal
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
        implicitVal = result;
      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
        implicitVal = result;
      } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
        implicitVal = result;
      } else {
        parser.emitError(parser.getNameLoc(),
                         "expected a numeric value for implicitVal");
        return {};
      }
      break;
    }
    } // switch
    // Only the last item can omit the comma.
    if (parser.parseOptionalComma().failed())
      break;
  }

  // Close "}>" part.
  if (failed(parser.parseRBrace()))
    return {};
  if (failed(parser.parseGreater()))
    return {};

  // Construct struct-like storage for attribute.
  if (!lvlToDim || lvlToDim.isEmpty()) {
    lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
  }
  return parser.getChecked<SparseTensorEncodingAttr>(
      parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
      explicitVal, implicitVal, dimSlices);
}

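// For reference (illustrative, not exhaustive): a typical attribute accepted
// by the parser above, describing CSR with 32-bit positions and coordinates,
// is written as
//
//   #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed),
//     posWidth = 32,
//     crdWidth = 32
//   }>
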
void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
  auto map = static_cast<AffineMap>(getDimToLvl());
  // An empty affine map indicates the identity map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
  printer << "<{ map = ";
  printSymbols(map, printer);
  printer << '(';
  printDimensions(map, printer, getDimSlices());
  printer << ") -> (";
  printLevels(map, printer, getLvlTypes());
  printer << ')';
  // Print remaining members only for non-default values.
  if (getPosWidth())
    printer << ", posWidth = " << getPosWidth();
  if (getCrdWidth())
    printer << ", crdWidth = " << getCrdWidth();
  if (getExplicitVal())
    printer << ", explicitVal = " << getExplicitVal();
  if (getImplicitVal())
    printer << ", implicitVal = " << getImplicitVal();
  printer << " }>";
}

void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
                                            AsmPrinter &printer) const {
  if (map.getNumSymbols() == 0)
    return;
  printer << '[';
  for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
    printer << 's' << i << ", ";
  if (map.getNumSymbols() >= 1)
    printer << 's' << map.getNumSymbols() - 1;
  printer << ']';
}

void SparseTensorEncodingAttr::printDimensions(
    AffineMap &map, AsmPrinter &printer,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  if (!dimSlices.empty()) {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << " : " << dimSlices[i] << ", ";
    if (map.getNumDims() >= 1) {
      printer << 'd' << map.getNumDims() - 1 << " : "
              << dimSlices[map.getNumDims() - 1];
    }
  } else {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << ", ";
    if (map.getNumDims() >= 1)
      printer << 'd' << map.getNumDims() - 1;
  }
}

void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
                                           ArrayRef<LevelType> lvlTypes) const {
  for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
    map.getResult(i).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
  }
  if (map.getNumResults() >= 1) {
    auto lastIndex = map.getNumResults() - 1;
    map.getResult(lastIndex).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[lastIndex]);
  }
}

LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
    AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
    unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
  if (!acceptBitWidth(posWidth))
    return emitError() << "unexpected position bitwidth: " << posWidth;
  if (!acceptBitWidth(crdWidth))
    return emitError() << "unexpected coordinate bitwidth: " << crdWidth;

  // Verify every COO segment.
  auto *it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
  while (it != lvlTypes.end()) {
    if (it == lvlTypes.begin() ||
        !(it - 1)->isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>())
      return emitError() << "expected compressed or loose_compressed level "
                            "before singleton level";

    auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT);
    if (!std::all_of(it, curCOOEnd,
                     [](LevelType i) { return isSingletonLT(i); }))
      return emitError() << "expected all singleton lvlTypes "
                            "following a singleton level";
    // We can potentially support mixed SoA/AoS singleton levels.
    if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
          return it->isa<LevelPropNonDefault::SoA>() ==
                 i.isa<LevelPropNonDefault::SoA>();
        })) {
      return emitError() << "expected all singleton lvlTypes stored in the "
                            "same memory layout (SoA vs AoS).";
    }
    it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
  }

  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
    return emitError() << "Batch lvlType can only be leading levels.";

  // The SoA property can only be applied to singleton levels.
  auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
    return lt.isa<LevelPropNonDefault::SoA>();
  });
  if (llvm::any_of(soaLvls, [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      })) {
    return emitError() << "SoA is only applicable to singleton lvlTypes.";
  }

  // TODO: audit formats that actually are supported by the backend.
  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
      it != std::end(lvlTypes)) {
    if (it != lvlTypes.end() - 1)
      return emitError() << "expected n_out_of_m to be the last level type";
    if (!std::all_of(lvlTypes.begin(), it,
                     [](LevelType i) { return isDenseLT(i); }))
      return emitError() << "expected all dense lvlTypes "
                            "before a n_out_of_m level";
    if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
      if (!isBlockSparsity(dimToLvl)) {
        return emitError()
               << "expected 1xm block structure for n_out_of_m level";
      }
      auto sizes = getBlockSize(dimToLvl);
      unsigned coefficient = 0;
      for (const auto &elem : sizes) {
        if (elem != 0) {
          if (elem != coefficient && coefficient != 0) {
            return emitError() << "expected only one blocked level "
                                  "with the same coefficients";
          }
          coefficient = elem;
        }
      }
      if (coefficient != getM(*it)) {
        return emitError() << "expected coefficients of affine expressions "
                              "to be equal to m of n_out_of_m level";
      }
    }
  }
  // Before we can check that the level-rank is consistent/coherent
  // across all fields, we need to define it. The source-of-truth for
  // the `getLvlRank` method is the length of the level-types array,
  // since it must always be provided and have full rank; therefore we
  // use that same source-of-truth here.
  const Level lvlRank = lvlTypes.size();
  if (lvlRank == 0)
    return emitError() << "expected a non-empty array for lvlTypes";
  // We save `dimRank` here because we'll also need it to verify `dimSlices`.
  const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
  if (dimToLvl) {
    if (dimToLvl.getNumResults() != lvlRank)
      return emitError()
             << "level-rank mismatch between dimToLvl and lvlTypes: "
             << dimToLvl.getNumResults() << " != " << lvlRank;
    auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
    // Symbols can't be inferred but are acceptable.
    if (!inferRes && dimToLvl.getNumSymbols() == 0)
      return emitError() << "failed to infer lvlToDim from dimToLvl";
    if (lvlToDim && (inferRes != lvlToDim))
      return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
    if (dimRank > lvlRank)
      return emitError() << "unexpected dimToLvl mapping from " << dimRank
                         << " to " << lvlRank;
  }
  if (!dimSlices.empty()) {
    if (dimSlices.size() != dimRank)
      return emitError()
             << "dimension-rank mismatch between dimSlices and dimToLvl: "
             << dimSlices.size() << " != " << dimRank;
    // Compiler support for `dimSlices` currently requires that the two
    // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
    if (dimRank != lvlRank)
      return emitError()
             << "dimSlices expected dimension-rank to match level-rank: "
             << dimRank << " != " << lvlRank;
  }
  return success();
}

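// Illustrative only: an encoding whose level types are, e.g.,
//   (d0, d1) -> (d0 : singleton, d1 : compressed)
// is rejected by the verifier above ("expected compressed or loose_compressed
// level before singleton level"), since a COO segment must begin with a
// (loose_)compressed level that delimits its coordinate lists.
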
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<Size> dimShape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity. In particular, this ensures that the
  // level-rank is coherent across all the fields.
  if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
                    getPosWidth(), getCrdWidth(), getExplicitVal(),
                    getImplicitVal(), getDimSlices())))
    return failure();
  // Check integrity with tensor type specifics. In particular, we
  // need only check that the dimension-rank of the tensor agrees with
  // the dimension-rank of the encoding.
  const Dimension dimRank = dimShape.size();
  if (dimRank == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimRank() != dimRank)
    return emitError()
           << "dimension-rank mismatch between encoding and tensor shape: "
           << getDimRank() << " != " << dimRank;
  if (auto expVal = getExplicitVal()) {
    Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
    if (attrType != elementType) {
      return emitError() << "explicit value type mismatch between encoding and "
                         << "tensor element type: " << attrType
                         << " != " << elementType;
    }
  }
  if (auto impVal = getImplicitVal()) {
    Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
    if (attrType != elementType) {
      return emitError() << "implicit value type mismatch between encoding and "
                         << "tensor element type: " << attrType
                         << " != " << elementType;
    }
    // Currently, we only support zero as the implicit value.
    auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
    auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
    auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
    if ((impFVal && impFVal.getValue().isNonZero()) ||
        (impIntVal && !impIntVal.getValue().isZero()) ||
        (impComplexVal && (impComplexVal.getImag().isNonZero() ||
                           impComplexVal.getReal().isNonZero()))) {
      return emitError() << "implicit value must be zero";
    }
  }
  return success();
}

Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
  SmallVector<COOSegment> coo = getCOOSegments();
  assert(coo.size() == 1 || coo.empty());
  if (!coo.empty() && coo.front().isAoS()) {
    return coo.front().lvlRange.first;
  }
  return getLvlRank();
}

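// Illustrative only: for a 3-D AoS COO encoding such as
//   (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique),
//                    d2 : singleton),
// getAoSCOOStart() returns level 0, the first level of the COO segment;
// for CSR, which has no COO segment, it returns getLvlRank() as a sentinel.
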
SmallVector<COOSegment>
mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
  SmallVector<COOSegment> ret;
  if (getLvlRank() <= 1)
    return ret;

  ArrayRef<LevelType> lts = getLvlTypes();
  Level l = 0;
  while (l < getLvlRank()) {
    auto lt = lts[l];
    if (lt.isa<LevelFormat::Compressed>() ||
        lt.isa<LevelFormat::LooseCompressed>()) {
      auto cur = lts.begin() + l;
      auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      });
      unsigned cooLen = std::distance(cur, end);
      if (cooLen > 1) {
        // To support mixed SoA/AoS COO, we should break the segment when the
        // storage scheme changes; for now we faithfully assume that all
        // consecutive singleton levels share the same storage format, as
        // enforced by the STEA verifier.
        ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
                                 lts[l + 1].isa<LevelPropNonDefault::SoA>()});
      }
      l += cooLen;
    } else {
      l++;
    }
  }
  return ret;
}

//===----------------------------------------------------------------------===//
// SparseTensorType Methods.
//===----------------------------------------------------------------------===//

bool SparseTensorType::isCOOType(Level startLvl, bool isUnique) const {
  if (!hasEncoding())
    return false;
  if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
    return false;
  for (Level l = startLvl + 1; l < lvlRank; ++l)
    if (!isSingletonLvl(l))
      return false;
  // If isUnique is true, then make sure that the last level is unique;
  // that is, when lvlRank == 1, the only compressed level is unique,
  // and when lvlRank > 1, the last singleton is unique.
  return !isUnique || isUniqueLvl(lvlRank - 1);
}

RankedTensorType
SparseTensorType::getCOOType(bool ordered) const {
  SmallVector<LevelType> lvlTypes;
  lvlTypes.reserve(lvlRank);
  // A non-unique compressed level at the beginning (unless this is
  // also the last level, in which case it is unique).
  lvlTypes.push_back(
      *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
  if (lvlRank > 1) {
    // Followed by n-2 non-unique singleton levels.
    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
                *buildLevelType(LevelFormat::Singleton, ordered, false));
    // Ends with a unique singleton level.
    lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
  }
  auto enc = SparseTensorEncodingAttr::get(
      getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal());
  return RankedTensorType::get(getDimShape(), getElementType(), enc);
}

//===----------------------------------------------------------------------===//
// Convenience Methods.
//===----------------------------------------------------------------------===//

SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
  if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
    return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
  if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
    return mdtp.getEncoding();
  return nullptr;
}

AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
                                             MLIRContext *context) {
  auto map = static_cast<AffineMap>(dimToLvl);
  AffineMap lvlToDim;
  // Return an empty lvlToDim when inference is not successful.
  if (!map || map.getNumSymbols() != 0) {
    lvlToDim = AffineMap();
  } else if (map.isPermutation()) {
    lvlToDim = inversePermutation(map);
  } else if (isBlockSparsity(map)) {
    lvlToDim = inverseBlockSparsity(map, context);
  }
  return lvlToDim;
}

AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
                                                    MLIRContext *context) {
  SmallVector<AffineExpr> lvlExprs;
  auto numLvls = dimToLvl.getNumResults();
  lvlExprs.reserve(numLvls);
  // lvlExprComponents stores information about the floordiv and mod operations
  // applied to the same dimension, so as to build the lvlToDim map.
  std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
  for (unsigned i = 0, n = numLvls; i < n; i++) {
    auto result = dimToLvl.getResult(i);
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::FloorDiv) {
        // Position of the dimension in dimToLvl.
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
               "expected only one floordiv for each dimension");
        SmallVector<AffineExpr, 3> components;
        // Level variable for floordiv.
        components.push_back(getAffineDimExpr(i, context));
        // Multiplier.
        components.push_back(binOp.getRHS());
        // The map key is the position of the dimension.
        lvlExprComponents[pos] = components;
      } else if (result.getKind() == AffineExprKind::Mod) {
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
               "expected floordiv before mod");
        // Add the level variable for mod to the same vector
        // as the corresponding floordiv.
        lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
      } else {
        assert(false && "expected floordiv or mod");
      }
    } else {
      lvlExprs.push_back(getAffineDimExpr(i, context));
    }
  }
  // Build lvlExprs from lvlExprComponents.
  // For example, for il = i floordiv 2 and ii = i mod 2, the components
  // would be [il, 2, ii]. They can be used to build the AffineExpr
  // i = il * 2 + ii in lvlToDim.
  for (auto &components : lvlExprComponents) {
    assert(components.second.size() == 3 &&
           "expected 3 components to build lvlExprs");
    auto mulOp = getAffineBinaryOpExpr(
        AffineExprKind::Mul, components.second[0], components.second[1]);
    auto addOp =
        getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
    lvlExprs.push_back(addOp);
  }
  return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
}

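// Illustrative only: for the 2x2 BSR map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2),
// the routine above collects, per dimension, [level var of the floordiv,
// multiplier, level var of the mod] and rebuilds the inverse map
//   (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 2 + l3).
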
SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
  assert(isBlockSparsity(dimToLvl) &&
         "expected dimToLvl to be block sparsity for calling getBlockSize");
  SmallVector<unsigned> blockSize;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::Mod) {
        blockSize.push_back(
            dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
      }
    } else {
      blockSize.push_back(0);
    }
  }
  return blockSize;
}

bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
  if (!dimToLvl)
    return false;
  std::map<unsigned, int64_t> coefficientMap;
  bool hasBlock = false;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      // Check for "dim op const".
      auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
      auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
      if (!dimOp || !conOp || conOp.getValue() <= 0)
        return false;
      // Inspect "dim / const" or "dim % const".
      auto pos = dimOp.getPosition();
      if (binOp.getKind() == AffineExprKind::FloorDiv) {
        // Expect only one floordiv for each dimension.
        if (coefficientMap.find(pos) != coefficientMap.end())
          return false;
        // Record the coefficient of the floordiv.
        coefficientMap[pos] = conOp.getValue();
      } else if (binOp.getKind() == AffineExprKind::Mod) {
        // Expect a floordiv before the mod.
        if (coefficientMap.find(pos) == coefficientMap.end())
          return false;
        // Expect the mod to have the same coefficient as the floordiv.
        if (conOp.getValue() != coefficientMap[pos])
          return false;
        hasBlock = true;
      } else {
        return false;
      }
    } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
      auto pos = dimOp.getPosition();
      // Expect the dim to be unset.
      if (coefficientMap.find(pos) != coefficientMap.end())
        return false;
      coefficientMap[pos] = 0;
    } else {
      return false;
    }
  }
  return hasBlock;
}

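// Illustrative only: the 2x2 BSR map above satisfies isBlockSparsity (each
// "d floordiv c" is matched by a "d mod c" with the same constant), and
// getBlockSize returns {2, 2} for it; a plain permutation such as
// (d0, d1) -> (d1, d0) contains no mod and is therefore not block sparsity.
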
bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
  auto hasNonIdentityMap = [](Value v) {
    auto stt = tryGetSparseTensorType(v);
    return stt && !stt->isIdentity();
  };

  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
         llvm::any_of(op->getResults(), hasNonIdentityMap);
}

Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
  if (enc) {
    assert(enc.isPermutation() && "Non permutation map not supported");
    if (const auto dimToLvl = enc.getDimToLvl())
      return dimToLvl.getDimPosition(l);
  }
  return l;
}

Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
  if (enc) {
    assert(enc.isPermutation() && "Non permutation map not supported");
    if (const auto lvlToDim = enc.getLvlToDim())
      return lvlToDim.getDimPosition(d);
  }
  return d;
}

/// We normalize the sparse tensor encoding attribute by always using
/// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
/// as other variants) lead to the same storage specifier type, and by
/// stripping irrelevant fields that do not alter the sparse tensor memory
/// layout.
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
  SmallVector<LevelType> lts;
  for (auto lt : enc.getLvlTypes())
    lts.push_back(lt.stripStorageIrrelevantProperties());

  return SparseTensorEncodingAttr::get(
      enc.getContext(), lts,
      AffineMap(), // dimToLvl (irrelevant to storage specifier)
      AffineMap(), // lvlToDim (irrelevant to storage specifier)
      // Always use `index` for memSize and lvlSize instead of reusing
      // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
      // value for different bitwidths; it also avoids casting between index
      // and integer (returned by DimOp).
      0, 0,
      Attribute(), // explicitVal (irrelevant to storage specifier)
      Attribute(), // implicitVal (irrelevant to storage specifier)
      enc.getDimSlices());
}

StorageSpecifierType
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
  return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
}

StorageSpecifierType
StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                                 MLIRContext *ctx,
                                 SparseTensorEncodingAttr encoding) {
  return Base::getChecked(emitError, ctx,
                          getNormalizedEncodingForSpecifier(encoding));
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
//===----------------------------------------------------------------------===//

static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
  return success(lvl < getSparseTensorType(tensor).getLvlRank());
}

static LogicalResult isMatchingWidth(Value mem, unsigned width) {
  const Type etp = getMemRefType(mem).getElementType();
  return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
}

static LogicalResult verifySparsifierGetterSetter(
    StorageSpecifierKind mdKind, std::optional<Level> lvl,
    TypedValue<StorageSpecifierType> md, Operation *op) {
  if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
    return op->emitError(
        "redundant level argument for querying value memory size");
  }

  const auto enc = md.getType().getEncoding();
  const Level lvlRank = enc.getLvlRank();

  if (mdKind == StorageSpecifierKind::DimOffset ||
      mdKind == StorageSpecifierKind::DimStride)
    if (!enc.isSlice())
      return op->emitError("requested slice data on non-slice tensor");

  if (mdKind != StorageSpecifierKind::ValMemSize) {
    if (!lvl)
      return op->emitError("missing level argument");

    const Level l = lvl.value();
    if (l >= lvlRank)
      return op->emitError("requested level is out of bounds");

    if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
      return op->emitError(
          "requested position memory size on a singleton level");
  }
  return success();
}

static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
  switch (kind) {
  case SparseTensorFieldKind::CrdMemRef:
    return stt.getCrdType();
  case SparseTensorFieldKind::PosMemRef:
    return stt.getPosType();
  case SparseTensorFieldKind::ValMemRef:
    return stt.getElementType();
  case SparseTensorFieldKind::StorageSpec:
    return nullptr;
  }
  llvm_unreachable("Unrecognizable FieldKind");
}

static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                      SparseTensorType stt,
                                      RankedTensorType valTp,
                                      TypeRange lvlTps) {
  if (requiresStaticShape && !stt.hasStaticDimShape())
    return op->emitError("the sparse-tensor must have static shape");
  if (!stt.hasEncoding())
    return op->emitError("the sparse-tensor must have an encoding attribute");

  // Verify the trailing COO.
  Level cooStartLvl = stt.getAoSCOOStart();
  if (cooStartLvl < stt.getLvlRank()) {
    // We only support trailing COO for now; it must be the last input.
    auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
    // The coordinates should be in the shape of <? x rank>.
    unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
    if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
      return op->emitError("input/output trailing COO level-ranks don't match");
    }
  }

  // Verify that all types match.
  StorageLayout layout(stt.getEncoding());
  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
    return op->emitError("inconsistent number of fields between input/output");

  unsigned idx = 0;
  bool misMatch = false;
  layout.foreachField([&idx, &misMatch, stt, valTp,
                       lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
                               Level lvl, LevelType lt) -> bool {
    if (fKind == SparseTensorFieldKind::StorageSpec)
      return true;

    Type inputTp = nullptr;
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      inputTp = valTp;
    } else {
      assert(fid == idx && stt.getLvlType(lvl) == lt);
      inputTp = lvlTps[idx++];
    }
    // The input element type and expected element type should match.
    Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
    Type expElemTp = getFieldElemType(stt, fKind);
    if (inpElemTp != expElemTp) {
      misMatch = true;
      return false; // to terminate the iteration
    }
    return true;
  });

  if (misMatch)
    return op->emitError("input/output element-types don't match");
  return success();
}

LogicalResult AssembleOp::verify() {
  const auto valuesTp = getRankedTensorType(getValues());
  const auto lvlsTp = getLevels().getTypes();
  const auto resTp = getSparseTensorType(getResult());
  return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
}

LogicalResult DisassembleOp::verify() {
  if (getOutValues().getType() != getRetValues().getType())
    return emitError("output values and return value type mismatch");

  for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
    if (ot.getType() != rt.getType())
      return emitError("output levels and return levels type mismatch");

  const auto valuesTp = getRankedTensorType(getRetValues());
  const auto lvlsTp = getRetLevels().getTypes();
  const auto srcTp = getSparseTensorType(getTensor());
  return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
}

LogicalResult ConvertOp::verify() {
  if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
    if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
      if (tp1.getRank() != tp2.getRank())
        return emitError("unexpected conversion mismatch in rank");
      auto dstEnc =
          llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
      if (dstEnc && dstEnc.isSlice())
        return emitError("cannot convert to a sparse tensor slice");

      auto shape1 = tp1.getShape();
      auto shape2 = tp2.getShape();
      // Accept size matches between the source and the destination type
      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
      for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
          return emitError("unexpected conversion mismatch in dimension ") << d;
      return success();
    }
  }
  return emitError("unexpected type in convert");
}

OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
  if (getType() == getSource().getType())
    return getSource();
  return {};
}

bool ConvertOp::needsExtraSort() {
  SparseTensorType srcStt = getSparseTensorType(getSource());
  SparseTensorType dstStt = getSparseTensorType(getDest());

  // We do not need an extra sort when returning unordered sparse tensors or
  // dense tensors, since dense tensors support random access.
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
      srcStt.hasSameDimToLvl(dstStt)) {
    return false;
  }

  // Source and dest tensors are ordered in different ways. We only do direct
  // dense to sparse conversion when the dense input is defined by a sparse
  // constant. Note that we can theoretically always directly convert from
  // dense inputs by rotating dense loops, but it leads to bad cache locality
  // and hurts performance.
  if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
    if (isa<SparseElementsAttr>(constOp.getValue()))
      return false;

  return true;
}

LogicalResult CrdTranslateOp::verify() {
  uint64_t inRank = getEncoder().getLvlRank();
  uint64_t outRank = getEncoder().getDimRank();

  if (getDirection() == CrdTransDirectionKind::dim2lvl)
    std::swap(inRank, outRank);

  if (inRank != getInCrds().size() || outRank != getOutCrds().size())
    return emitError("Coordinate rank mismatch with encoding");

  return success();
}

LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                   SmallVectorImpl<OpFoldResult> &results) {
  if (getEncoder().isIdentity()) {
    results.assign(getInCrds().begin(), getInCrds().end());
    return success();
  }
  if (getEncoder().isPermutation()) {
    AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
                         ? getEncoder().getDimToLvl()
                         : getEncoder().getLvlToDim();
    for (AffineExpr exp : perm.getResults())
      results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
    return success();
  }

  // Fuse dim2lvl/lvl2dim pairs.
  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
  bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
    return v.getDefiningOp() == def;
  });
  if (!sameDef)
    return failure();

  bool oppositeDir = def.getDirection() != getDirection();
  bool sameOracle =
      def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
  bool sameCount = def.getNumResults() == getInCrds().size();
  if (!oppositeDir || !sameOracle || !sameCount)
    return failure();

  // The definition produces the coordinates in the same order as the input
  // coordinates.
  bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
                                [](auto valuePair) {
                                  auto [lhs, rhs] = valuePair;
                                  return lhs == rhs;
                                });

  if (!sameOrder)
    return failure();
  // l1 = dim2lvl (lvl2dim l0)
  // ==> l0
  results.append(def.getInCrds().begin(), def.getInCrds().end());
  return success();
}

void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
                  int64_t index) {
  Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
  return build(builder, state, source, val);
}

LogicalResult LvlOp::verify() {
  if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
    auto stt = getSparseTensorType(getSource());
    if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
      return emitError(
          "Level index exceeds the rank of the input sparse tensor");
  }
  return success();
}

std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability LvlOp::getSpeculatability() {
  auto constantIndex = getConstantLvlIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  assert(constantIndex <
         cast<RankedTensorType>(getSource().getType()).getRank());
  return Speculation::Speculatable;
}

OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
  auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!lvlIndex)
    return {};

  Level lvl = lvlIndex.getAPSInt().getZExtValue();
  auto stt = getSparseTensorType(getSource());
  if (lvl >= stt.getLvlRank()) {
    // Follows the same convention used by the tensor.dim operation. Out of
    // bounds indices produce undefined behavior but are still valid IR.
    // Don't choke on them.
    return {};
  }

  // Helper lambda to build an IndexAttr.
  auto getIndexAttr = [this](int64_t lvlSz) {
    return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
  };

  SmallVector<Size> lvlShape = stt.getLvlShape();
  if (!ShapedType::isDynamic(lvlShape[lvl]))
    return getIndexAttr(lvlShape[lvl]);

  return {};
}

void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                             SparseTensorEncodingAttr dstEnc, Value source) {
  auto srcStt = getSparseTensorType(source);
  SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
  SmallVector<int64_t> dstDimShape =
      dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
  auto dstTp =
      RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
  return build(odsBuilder, odsState, dstTp, source);
}

LogicalResult ReinterpretMapOp::verify() {
  auto srcStt = getSparseTensorType(getSource());
  auto dstStt = getSparseTensorType(getDest());
  ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
  ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();

  if (srcLvlTps.size() != dstLvlTps.size())
    return emitError("Level rank mismatch between source/dest tensors");

  for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
    if (srcLvlTp != dstLvlTp)
      return emitError("Level type mismatch between source/dest tensors");

  if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
      srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
    return emitError("Crd/Pos width mismatch between source/dest tensors");
  }

  if (srcStt.getElementType() != dstStt.getElementType())
    return emitError("Element type mismatch between source/dest tensors");

  SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
  SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
  for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
    if (srcLvlSz != dstLvlSz) {
      // Should we allow one side to be of dynamic size, e.g., should <?x?> be
      // compatible with <3x4>? For now, we require all the level sizes to
      // match *exactly* for simplicity.
      return emitError("Level size mismatch between source/dest tensors");
    }
  }

  return success();
}

1589 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1590  if (getSource().getType() == getDest().getType())
1591  return getSource();
1592 
1593  if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1594  // A -> B, B -> A ==> A
1595  if (def.getSource().getType() == getDest().getType())
1596  return def.getSource();
1597  }
1598  return {};
1599 }
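// Example (illustrative): a reinterpret_map round trip folds away entirely.
// Assuming encodings #A and #B that describe the same storage,
//   %b = sparse_tensor.reinterpret_map %a : tensor<?x?xf64, #A> to tensor<?x?xf64, #B>
//   %c = sparse_tensor.reinterpret_map %b : tensor<?x?xf64, #B> to tensor<?x?xf64, #A>
// folds %c to %a (the A -> B, B -> A case above); a map onto the identical
// type folds to the source directly.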
1600 
1601 template <typename ToBufferOp>
1602 static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
1603  OpaqueProperties prop,
1604  RegionRange region,
1605  SmallVectorImpl<mlir::Type> &ret) {
1606  typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1607  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
1608  Type elemTp = nullptr;
1609  bool withStride = false;
1610  if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1611  elemTp = stt.getPosType();
1612  } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1613  std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1614  elemTp = stt.getCrdType();
1615  if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1616  withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
1617  } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1618  elemTp = stt.getElementType();
1619  }
1620 
1621  assert(elemTp && "unhandled operation.");
1622  SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
1623  bufShape.push_back(ShapedType::kDynamic);
1624 
1625  auto layout = withStride ? StridedLayoutAttr::get(
1626  stt.getContext(), ShapedType::kDynamic,
1627  {ShapedType::kDynamic})
1628  : StridedLayoutAttr();
1629  ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1630  return success();
1631 }
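// Example (illustrative; assumes a non-batched #CSR encoding with explicit
// 64-bit position/coordinate widths): the types inferred above would be
//   to_positions   -> memref<?xi64>
//   to_coordinates -> memref<?xi64> (strided when addressing AoS COO levels)
//   to_values      -> memref<?xf64>
// i.e. a dynamically sized 1-d memref, prefixed by any batch dimensions.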
1632 
1633 LogicalResult ToPositionsOp::verify() {
1634  auto stt = getSparseTensorType(getTensor());
1635  if (failed(lvlIsInBounds(getLevel(), getTensor())))
1636  return emitError("requested level is out of bounds");
1637  if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
1638  return emitError("unexpected type for positions");
1639  return success();
1640 }
1641 
1642 LogicalResult
1643 ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1644  ValueRange ops, DictionaryAttr attr,
1645  OpaqueProperties prop, RegionRange region,
1646  SmallVectorImpl<Type> &ret) {
1647  return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
1648 }
1649 
1650 LogicalResult ToCoordinatesOp::verify() {
1651  auto stt = getSparseTensorType(getTensor());
1652  if (failed(lvlIsInBounds(getLevel(), getTensor())))
1653  return emitError("requested level is out of bounds");
1654  if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
1655  return emitError("unexpected type for coordinates");
1656  return success();
1657 }
1658 
1659 LogicalResult
1660 ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1661  ValueRange ops, DictionaryAttr attr,
1662  OpaqueProperties prop, RegionRange region,
1663  SmallVectorImpl<Type> &ret) {
1664  return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
1665 }
1666 
1667 LogicalResult ToCoordinatesBufferOp::verify() {
1668  auto stt = getSparseTensorType(getTensor());
1669  if (stt.getAoSCOOStart() >= stt.getLvlRank())
1670  return emitError("expected sparse tensor with a COO region");
1671  return success();
1672 }
1673 
1674 LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1675  MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
1676  DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
1677  SmallVectorImpl<Type> &ret) {
1678  return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
1679  ret);
1680 }
1681 
1682 LogicalResult ToValuesOp::verify() {
1683  auto stt = getSparseTensorType(getTensor());
1684  auto mtp = getMemRefType(getResult());
1685  if (stt.getElementType() != mtp.getElementType())
1686  return emitError("unexpected mismatch in element types");
1687  return success();
1688 }
1689 
1690 LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1691  std::optional<Location> loc,
1692  ValueRange ops, DictionaryAttr attr,
1693  OpaqueProperties prop,
1694  RegionRange region,
1695  SmallVectorImpl<Type> &ret) {
1696  return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
1697 }
1698 
1699 LogicalResult ToSliceOffsetOp::verify() {
1700  auto rank = getRankedTensorType(getSlice()).getRank();
1701  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1702  return emitError("requested dimension out of bounds");
1703  return success();
1704 }
1705 
1706 LogicalResult ToSliceStrideOp::verify() {
1707  auto rank = getRankedTensorType(getSlice()).getRank();
1708  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1709  return emitError("requested dimension out of bounds");
1710  return success();
1711 }
1712 
1713 LogicalResult GetStorageSpecifierOp::verify() {
1714  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1715  getSpecifier(), getOperation());
1716 }
1717 
1718 template <typename SpecifierOp>
1719 static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
1720  return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1721 }
1722 
1723 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1724  const StorageSpecifierKind kind = getSpecifierKind();
1725  const auto lvl = getLevel();
1726  for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1727  if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1728  return op.getValue();
1729  return {};
1730 }
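// Example (illustrative): the fold above looks through chains of
// set_storage_specifier ops, so in
//   %s1 = sparse_tensor.storage_specifier.set %s0 lvl_sz at 0 with %v
//   %r  = sparse_tensor.storage_specifier.get %s1 lvl_sz at 0
// %r folds to %v, while a get of a different kind or level keeps searching
// further up the def chain.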
1731 
1732 LogicalResult SetStorageSpecifierOp::verify() {
1733  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1734  getSpecifier(), getOperation());
1735 }
1736 
1737 template <class T>
1738 static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1739  const char *regionName,
1740  TypeRange inputTypes, Type outputType) {
1741  unsigned numArgs = region.getNumArguments();
1742  unsigned expectedNum = inputTypes.size();
1743  if (numArgs != expectedNum)
1744  return op->emitError() << regionName << " region must have exactly "
1745  << expectedNum << " arguments";
1746 
1747  for (unsigned i = 0; i < numArgs; i++) {
1748  Type typ = region.getArgument(i).getType();
1749  if (typ != inputTypes[i])
1750  return op->emitError() << regionName << " region argument " << (i + 1)
1751  << " type mismatch";
1752  }
1753  Operation *term = region.front().getTerminator();
1754  YieldOp yield = dyn_cast<YieldOp>(term);
1755  if (!yield)
1756  return op->emitError() << regionName
1757  << " region must end with sparse_tensor.yield";
1758  if (!yield.hasSingleResult() ||
1759  yield.getSingleResult().getType() != outputType)
1760  return op->emitError() << regionName << " region yield type mismatch";
1761 
1762  return success();
1763 }
1764 
1765 LogicalResult BinaryOp::verify() {
1766  NamedAttrList attrs = (*this)->getAttrs();
1767  Type leftType = getX().getType();
1768  Type rightType = getY().getType();
1769  Type outputType = getOutput().getType();
1770  Region &overlap = getOverlapRegion();
1771  Region &left = getLeftRegion();
1772  Region &right = getRightRegion();
1773 
1774  // Check correct number of block arguments and return type for each
1775  // non-empty region.
1776  if (!overlap.empty()) {
1777  if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1778  TypeRange{leftType, rightType}, outputType)))
1779  return failure();
1780  }
1781  if (!left.empty()) {
1782  if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1783  outputType)))
1784  return failure();
1785  } else if (getLeftIdentity()) {
1786  if (leftType != outputType)
1787  return emitError("left=identity requires first argument to have the same "
1788  "type as the output");
1789  }
1790  if (!right.empty()) {
1791  if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1792  outputType)))
1793  return failure();
1794  } else if (getRightIdentity()) {
1795  if (rightType != outputType)
1796  return emitError("right=identity requires second argument to have the "
1797  "same type as the output");
1798  }
1799  return success();
1800 }
1801 
1802 LogicalResult UnaryOp::verify() {
1803  Type inputType = getX().getType();
1804  Type outputType = getOutput().getType();
1805 
1806  // Check correct number of block arguments and return type for each
1807  // non-empty region.
1808  Region &present = getPresentRegion();
1809  if (!present.empty()) {
1810  if (failed(verifyNumBlockArgs(this, present, "present",
1811  TypeRange{inputType}, outputType)))
1812  return failure();
1813  }
1814  Region &absent = getAbsentRegion();
1815  if (!absent.empty()) {
1816  if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
1817  outputType)))
1818  return failure();
1819  // Absent branch can only yield invariant values.
1820  Block *absentBlock = &absent.front();
1821  Block *parent = getOperation()->getBlock();
1822  Value absentVal =
1823  cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
1824  if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
1825  if (arg.getOwner() == parent)
1826  return emitError("absent region cannot yield linalg argument");
1827  } else if (Operation *def = absentVal.getDefiningOp()) {
1828  if (!isa<arith::ConstantOp>(def) &&
1829  (def->getBlock() == absentBlock || def->getBlock() == parent))
1830  return emitError("absent region cannot yield locally computed value");
1831  }
1832  }
1833  return success();
1834 }
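// Example (illustrative): a well-formed absent region may only yield a truly
// invariant value, typically a constant defined inside the region itself:
//   absent = {
//     %cf0 = arith.constant 0.0 : f64
//     sparse_tensor.yield %cf0 : f64
//   }
// Yielding a block argument of the enclosing linalg block, or any
// non-constant value computed locally, is rejected by the checks above.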
1835 
1836 bool ConcatenateOp::needsExtraSort() {
1837  SparseTensorType dstStt = getSparseTensorType(*this);
1838  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1839  return false;
1840 
1841  bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1842  return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1843  });
1844  // TODO: When conDim != 0, as long as conDim corresponds to the first level
1845  // in all input/output buffers, and all input/output buffers have the same
1846  // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenating
1847  // CSC matrices along columns).
1848  bool directLowerable =
1849  allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1850  return !directLowerable;
1851 }
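// Example (illustrative): concatenating two ordered CSR matrices along
// dimension 0 appends whole rows and keeps coordinates ordered, so no
// temporary COO buffer is needed; concatenating the same matrices along
// dimension 1 interleaves rows and therefore needs the extra sort.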
1852 
1853 LogicalResult ConcatenateOp::verify() {
1854  const auto dstTp = getSparseTensorType(*this);
1855  const Dimension concatDim = getDimension();
1856  const Dimension dimRank = dstTp.getDimRank();
1857 
1858  if (getInputs().size() <= 1)
1859  return emitError("Need at least two tensors to concatenate.");
1860 
1861  if (concatDim >= dimRank)
1862  return emitError(llvm::formatv(
1863  "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1864  concatDim, dimRank));
1865 
1866  for (const auto &it : llvm::enumerate(getInputs())) {
1867  const auto i = it.index();
1868  const auto srcTp = getSparseTensorType(it.value());
1869  if (srcTp.hasDynamicDimShape())
1870  return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1871  const Dimension srcDimRank = srcTp.getDimRank();
1872  if (srcDimRank != dimRank)
1873  return emitError(
1874  llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1875  "from the output tensor (rank={2}).",
1876  i, srcDimRank, dimRank));
1877  }
1878 
1879  for (Dimension d = 0; d < dimRank; d++) {
1880  const Size dstSh = dstTp.getDimShape()[d];
1881  if (d == concatDim) {
1882  if (!ShapedType::isDynamic(dstSh)) {
1883  // If we reach here, then all inputs have static shapes. So we
1884  // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1885  // to avoid redundant assertions in the loop.
1886  Size sumSz = 0;
1887  for (const auto src : getInputs())
1888  sumSz += getSparseTensorType(src).getDimShape()[d];
1889  // If all dimensions are statically known, the sum of all the input
1890  // dimensions should be equal to the output dimension.
1891  if (sumSz != dstSh)
1892  return emitError(
1893  "The concatenation dimension of the output tensor should be the "
1894  "sum of all the concatenation dimensions of the input tensors.");
1895  }
1896  } else {
1897  Size prev = dstSh;
1898  for (const auto src : getInputs()) {
1899  const auto sh = getSparseTensorType(src).getDimShape()[d];
1900  if (!ShapedType::isDynamic(prev) && sh != prev)
1901  return emitError("All dimensions (except for the concatenating one) "
1902  "should be equal.");
1903  prev = sh;
1904  }
1905  }
1906  }
1907 
1908  return success();
1909 }
1910 
1911 void PushBackOp::build(OpBuilder &builder, OperationState &result,
1912  Value curSize, Value inBuffer, Value value) {
1913  build(builder, result, curSize, inBuffer, value, Value());
1914 }
1915 
1916 LogicalResult PushBackOp::verify() {
1917  if (Value n = getN()) {
1918  std::optional<int64_t> nValue = getConstantIntValue(n);
1919  if (nValue && nValue.value() < 1)
1920  return emitOpError("n must not be less than 1");
1921  }
1922  return success();
1923 }
1924 
1925 LogicalResult CompressOp::verify() {
1926  const auto stt = getSparseTensorType(getTensor());
1927  if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1928  return emitOpError("incorrect number of coordinates");
1929  return success();
1930 }
1931 
1932 void ForeachOp::build(
1933  OpBuilder &builder, OperationState &result, Value tensor,
1934  ValueRange initArgs, AffineMapAttr order,
1935  function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
1936  bodyBuilder) {
1937  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
1938  // Builds foreach body.
1939  if (!bodyBuilder)
1940  return;
1941  const auto stt = getSparseTensorType(tensor);
1942  const Dimension dimRank = stt.getDimRank();
1943 
1944  // Starts with `dimRank`-many coordinates.
1945  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
1946  // Followed by one value.
1947  blockArgTypes.push_back(stt.getElementType());
1948  // Followed by the reduction variables.
1949  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
1950 
1951  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());
1952 
1953  OpBuilder::InsertionGuard guard(builder);
1954  auto &region = *result.regions.front();
1955  Block *bodyBlock =
1956  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1957  bodyBuilder(builder, result.location,
1958  bodyBlock->getArguments().slice(0, dimRank),
1959  bodyBlock->getArguments()[dimRank],
1960  bodyBlock->getArguments().drop_front(dimRank + 1));
1961 }
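// Example (illustrative): for a 2-d f64 tensor and a single f64 reduction
// argument, the body block created above has the signature
//   ^bb0(%d0: index, %d1: index, %val: f64, %acc: f64)
// i.e. dimRank-many coordinates, then the element value, then the init args.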
1962 
1963 LogicalResult ForeachOp::verify() {
1964  const auto t = getSparseTensorType(getTensor());
1965  const Dimension dimRank = t.getDimRank();
1966  const auto args = getBody()->getArguments();
1967 
1968  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1969  return emitError("Level traverse order does not match tensor's level rank");
1970 
1971  if (dimRank + 1 + getInitArgs().size() != args.size())
1972  return emitError("Unmatched number of arguments in the block");
1973 
1974  if (getNumResults() != getInitArgs().size())
1975  return emitError("Mismatch in number of init arguments and results");
1976 
1977  if (getResultTypes() != getInitArgs().getTypes())
1978  return emitError("Mismatch in types of init arguments and results");
1979 
1980  // Cannot mark this const, because the getters aren't.
1981  auto yield = cast<YieldOp>(getBody()->getTerminator());
1982  if (yield.getNumOperands() != getNumResults() ||
1983  yield.getOperands().getTypes() != getResultTypes())
1984  return emitError("Mismatch in types of yield values and results");
1985 
1986  const auto iTp = IndexType::get(getContext());
1987  for (Dimension d = 0; d < dimRank; d++)
1988  if (args[d].getType() != iTp)
1989  return emitError(
1990  llvm::formatv("Expecting Index type for argument at index {0}", d));
1991 
1992  const auto elemTp = t.getElementType();
1993  const auto valueTp = args[dimRank].getType();
1994  if (elemTp != valueTp)
1995  return emitError(llvm::formatv("Unmatched element type between input tensor "
1996  "and block argument, expected: {0}, got: {1}",
1997  elemTp, valueTp));
1998  return success();
1999 }
2000 
2001 OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
2002  if (getSparseTensorEncoding(getInputCoo().getType()) ==
2003  getSparseTensorEncoding(getResultCoo().getType()))
2004  return getInputCoo();
2005 
2006  return {};
2007 }
2008 
2009 LogicalResult ReorderCOOOp::verify() {
2010  SparseTensorType srcStt = getSparseTensorType(getInputCoo());
2011  SparseTensorType dstStt = getSparseTensorType(getResultCoo());
2012 
2013  if (!srcStt.isCOOType() || !dstStt.isCOOType())
2014  return emitError("Expected COO sparse tensors only");
2015 
2016  if (!srcStt.hasSameDimToLvl(dstStt))
2017  return emitError("Unmatched dim2lvl map between input and result COO");
2018 
2019  if (srcStt.getPosType() != dstStt.getPosType() ||
2020  srcStt.getCrdType() != dstStt.getCrdType() ||
2021  srcStt.getElementType() != dstStt.getElementType())
2022  return emitError("Unmatched storage format between input and result COO");
2023 
2024  return success();
2025 }
2026 
2027 LogicalResult ReduceOp::verify() {
2028  Type inputType = getX().getType();
2029  Region &formula = getRegion();
2030  return verifyNumBlockArgs(this, formula, "reduce",
2031  TypeRange{inputType, inputType}, inputType);
2032 }
2033 
2034 LogicalResult SelectOp::verify() {
2035  Builder b(getContext());
2036  Type inputType = getX().getType();
2037  Type boolType = b.getI1Type();
2038  Region &formula = getRegion();
2039  return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
2040  boolType);
2041 }
2042 
2043 LogicalResult SortOp::verify() {
2044  AffineMap xPerm = getPermMap();
2045  uint64_t nx = xPerm.getNumDims();
2046  if (nx < 1)
2047  return emitError(llvm::formatv("Expected rank(perm_map) >= 1, got {0}", nx));
2048 
2049  if (!xPerm.isPermutation())
2050  return emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
2051 
2052  // We can't check the size of the buffers when n or buffer dimensions aren't
2053  // compile-time constants.
2054  std::optional<int64_t> cn = getConstantIntValue(getN());
2055  if (!cn)
2056  return success();
2057 
2058  // Verify dimensions.
2059  const auto checkDim = [&](Value v, Size minSize, const char *message) -> LogicalResult {
2060  const Size sh = getMemRefType(v).getShape()[0];
2061  if (!ShapedType::isDynamic(sh) && sh < minSize)
2062  return emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
2063  return success(); };
2064  uint64_t n = cn.value();
2065  uint64_t ny = 0;
2066  if (auto nyAttr = getNyAttr())
2067  ny = nyAttr.getInt();
2068  if (failed(checkDim(getXy(), n * (nx + ny),
2069  "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
2070  return failure();
2071  for (Value opnd : getYs())
2072  if (failed(checkDim(opnd, n, "Expected dimension(y) >= n"))) return failure();
2073  return success();
2074 }
2075 
2076 //===----------------------------------------------------------------------===//
2077 // Sparse Tensor Iteration Operations.
2078 //===----------------------------------------------------------------------===//
2079 
2080 IterSpaceType IteratorType::getIterSpaceType() const {
2081  return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
2082  getHiLvl());
2083 }
2084 
2085 IteratorType IterSpaceType::getIteratorType() const {
2086  return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
2087 }
2088 
2089 /// Parses a level range in the form "$lo `to` $hi"
2090 /// or simply "$lo" if $hi - $lo = 1
2091 static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
2092  Level &lvlHi) {
2093  if (parser.parseInteger(lvlLo))
2094  return failure();
2095 
2096  if (succeeded(parser.parseOptionalKeyword("to"))) {
2097  if (parser.parseInteger(lvlHi))
2098  return failure();
2099  } else {
2100  lvlHi = lvlLo + 1;
2101  }
2102 
2103  if (lvlHi <= lvlLo)
2104  return parser.emitError(parser.getNameLoc(),
2105  "expected a level upper bound larger than the lower bound");
2106 
2107  return success();
2108 }
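// Example (illustrative): both of the following forms are accepted,
//   "1 to 3" -> lvlLo = 1, lvlHi = 3
//   "1"      -> lvlLo = 1, lvlHi = 2 (a singleton range)
// while "3 to 1" is rejected by the bound check above.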
2109 
2110 /// Parses a level range in the form "$lo `to` $hi"
2111 /// or simply "$lo" if $hi - $lo = 1
2112 static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
2113  IntegerAttr &lvlHiAttr) {
2114  Level lvlLo, lvlHi;
2115  if (parseLevelRange(parser, lvlLo, lvlHi))
2116  return failure();
2117 
2118  lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
2119  lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
2120  return success();
2121 }
2122 
2123 /// Prints a level range in the form "$lo `to` $hi"
2124 /// or simply "$lo" if $hi - $lo = 1
2125 static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2126 
2127  if (lo + 1 == hi)
2128  p << lo;
2129  else
2130  p << lo << " to " << hi;
2131 }
2132 
2133 /// Prints a level range in the form "$lo `to` $hi"
2134 /// or simply "$lo" if $hi - $lo = 1
2135 static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
2136  IntegerAttr lvlHi) {
2137  unsigned lo = lvlLo.getValue().getZExtValue();
2138  unsigned hi = lvlHi.getValue().getZExtValue();
2139  printLevelRange(p, lo, hi);
2140 }
2141 
2142 /// Parses a list of optionally defined values in the form of
2143 /// "(%val0, _, %val1, ...)", where `_` is used to annotate that the
2144 /// corresponding value is not defined (e.g., to represent an undefined
2145 /// coordinate in the sparse iteration space).
2146 static ParseResult parseOptionalDefinedList(
2147  OpAsmParser &parser, OperationState &state, I64BitSet &definedSet,
2148  SmallVectorImpl<OpAsmParser::Argument> &definedArgs,
2149  unsigned maxCnt = std::numeric_limits<unsigned>::max(),
2150  OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) {
2151  unsigned cnt = 0;
2152  ParseResult crdList =
2153  parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult {
2154  if (parser.parseOptionalKeyword("_")) {
2155  if (parser.parseArgument(definedArgs.emplace_back()))
2156  return failure();
2157  definedSet.set(cnt);
2158  }
2159  cnt += 1;
2160  return success();
2161  });
2162 
2163  if (cnt > maxCnt)
2164  return parser.emitError(parser.getNameLoc(),
2165  "parsed more values than expected.");
2166 
2167  if (failed(crdList)) {
2168  return parser.emitError(
2169  parser.getNameLoc(),
2170  "expecting SSA value or \"_\" for level coordinates");
2171  }
2172  assert(definedArgs.size() == definedSet.count());
2173  return success();
2174 }
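// Example (illustrative): parsing "(%arg0, _, %arg1)" produces two parsed
// arguments and definedSet = 0b101, with the `_` slot left undefined.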
2175 
2176 static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size,
2177  Block::BlockArgListType blocksArgs,
2178  I64BitSet definedSet) {
2179  if (definedSet.empty())
2180  return;
2181 
2182  for (unsigned i = 0; i < size; i++) {
2183  if (definedSet[i]) {
2184  p << blocksArgs.front();
2185  blocksArgs = blocksArgs.drop_front();
2186  } else {
2187  p << "_";
2188  }
2189  if (i != size - 1)
2190  p << ", ";
2191  }
2192  assert(blocksArgs.empty());
2193 }
2194 
2195 static ParseResult
2196 parseUsedCoordList(OpAsmParser &parser, OperationState &state,
2197  SmallVectorImpl<OpAsmParser::Argument> &coords) {
2198  // Parse "at(%crd0, _, ...)"
2199  I64BitSet crdUsedLvlSet;
2200  if (succeeded(parser.parseOptionalKeyword("at")) &&
2201  failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords)))
2202  return failure();
2203 
2204  // Always use IndexType for the coordinate.
2205  for (auto &coord : coords)
2206  coord.type = parser.getBuilder().getIndexType();
2207 
2208  // Set the CrdUsedLvl bitset.
2209  state.addAttribute("crdUsedLvls",
2210  parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
2211  return success();
2212 }
2213 
2214 static ParseResult
2215 parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
2216  SmallVectorImpl<OpAsmParser::Argument> &iterators,
2217  SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
2218  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
2219  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
2220 
2221  // Parse "%iters, ... in %spaces, ..."
2222  if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
2223  parser.parseOperandList(spaces))
2224  return failure();
2225 
2226  if (iterators.size() != spaces.size())
2227  return parser.emitError(
2228  parser.getNameLoc(),
2229  "mismatch in number of sparse iterators and sparse spaces");
2230 
2231  SmallVector<OpAsmParser::Argument> coords;
2232  if (failed(parseUsedCoordList(parser, state, coords)))
2233  return failure();
2234  size_t numCrds = coords.size();
2235 
2236  // Parse "iter_args(%arg = %init, ...)"
2237  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2238  if (hasIterArgs)
2239  if (parser.parseAssignmentList(blockArgs, initArgs))
2240  return failure();
2241 
2242  blockArgs.append(coords);
2243 
2244  SmallVector<Type> iterSpaceTps;
2245  // parse ": sparse_tensor.iter_space -> ret"
2246  if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
2247  return failure();
2248  if (iterSpaceTps.size() != spaces.size())
2249  return parser.emitError(parser.getNameLoc(),
2250  "mismatch in number of iteration space operands "
2251  "and iteration space types");
2252 
2253  for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
2254  IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
2255  if (!spaceTp)
2256  return parser.emitError(parser.getNameLoc(),
2257  "expected sparse_tensor.iter_space type for "
2258  "iteration space operands");
2259  it.type = spaceTp.getIteratorType();
2260  }
2261 
2262  if (hasIterArgs)
2263  if (parser.parseArrowTypeList(state.types))
2264  return failure();
2265 
2266  // Resolves input operands.
2267  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2268  state.operands))
2269  return failure();
2270 
2271  if (hasIterArgs) {
2272  // Strip off trailing args that are used for coordinates.
2273  MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
2274  if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2275  return parser.emitError(
2276  parser.getNameLoc(),
2277  "mismatch in number of iteration arguments and return values");
2278  }
2279 
2280  for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2281  it.type = tp;
2282  if (parser.resolveOperand(init, tp, state.operands))
2283  return failure();
2284  }
2285  }
2286  return success();
2287 }
2288 
2289 static ParseResult
2290 parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
2291  SmallVectorImpl<Value> &spacesVals,
2292  SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
2293 
2294  // Parse "(%spaces, ...)"
2295  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
2296  if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren))
2297  return failure();
2298 
2299  SmallVector<OpAsmParser::Argument> coords;
2300  if (failed(parseUsedCoordList(parser, state, coords)))
2301  return failure();
2302  size_t numCrds = coords.size();
2303 
2304  // Parse "iter_args(%arg = %init, ...)"
2305  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
2306  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2307  if (hasIterArgs)
2308  if (parser.parseAssignmentList(blockArgs, initArgs))
2309  return failure();
2310  blockArgs.append(coords);
2311 
2312  SmallVector<Type> iterSpaceTps;
2313  // parse ": (sparse_tensor.iter_space, ...) -> ret"
2314  if (parser.parseColon() || parser.parseLParen() ||
2315  parser.parseTypeList(iterSpaceTps) || parser.parseRParen())
2316  return failure();
2317 
2318  if (iterSpaceTps.size() != spaces.size())
2319  return parser.emitError(parser.getNameLoc(),
2320  "mismatch in number of iteration space operands "
2321  "and iteration space types");
2322 
2323  if (hasIterArgs)
2324  if (parser.parseArrowTypeList(state.types))
2325  return failure();
2326 
2327  // Resolves input sparse iteration spaces.
2328  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2329  spacesVals))
2330  return failure();
2331  state.operands.append(spacesVals);
2332 
2333  if (hasIterArgs) {
2334  // Strip off trailing args that are used for coordinates.
2335  MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
2336  if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2337  return parser.emitError(
2338  parser.getNameLoc(),
2339  "mismatch in number of iteration arguments and return values");
2340  }
2341 
2342  for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2343  it.type = tp;
2344  if (parser.resolveOperand(init, tp, state.operands))
2345  return failure();
2346  }
2347  }
2348  return success();
2349 }
2350 
2351 LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2352  MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
2353  DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
2354  SmallVectorImpl<Type> &ret) {
2355 
2356  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2357  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
2358  ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
2359  adaptor.getHiLvl()));
2360  return success();
2361 }
2362 
2363 LogicalResult ExtractIterSpaceOp::verify() {
2364  if (getLoLvl() >= getHiLvl())
2365  return emitOpError("expected level lower bound to be smaller than level upper bound");
2366 
2367  TypedValue<IteratorType> pIter = getParentIter();
2368  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2369  return emitOpError(
2370  "parent iterator should be specified iff level lower bound equals 0");
2371  }
2372 
2373  if (pIter) {
2374  IterSpaceType spaceTp = getExtractedSpace().getType();
2375  if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2376  return emitOpError(
2377  "mismatch in parent iterator encoding and iteration space encoding.");
2378 
2379  if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2380  return emitOpError("parent iterator should be used to extract an "
2381  "iteration space from a consecutive level.");
2382  }
2383 
2384  return success();
2385 }
2386 
2387 LogicalResult ExtractValOp::verify() {
2388  auto stt = getSparseTensorType(getTensor());
2389  auto itTp = getIterator().getType();
2390 
2391  if (stt.getEncoding() != itTp.getEncoding())
2392  return emitOpError("mismatch in tensor encoding and iterator encoding.");
2393 
2394  if (stt.getLvlRank() != itTp.getHiLvl())
2395  return emitOpError("must use the last-level iterator to extract values");
2396 
2397  return success();
2398 }
2399 
2400 struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
2401  using OpRewritePattern::OpRewritePattern;
2402 
2403  LogicalResult matchAndRewrite(IterateOp iterateOp,
2404  PatternRewriter &rewriter) const override {
2405  I64BitSet newUsedLvls(0);
2406  llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
2407  for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
2408  if (auto crd = iterateOp.getLvlCrd(i)) {
2409  if (crd->getUsers().empty())
2410  toRemove.set(crd->getArgNumber());
2411  else
2412  newUsedLvls.set(i);
2413  }
2414  }
2415 
2416  // All coordinates are used.
2417  if (toRemove.none())
2418  return failure();
2419 
2420  rewriter.startOpModification(iterateOp);
2421  iterateOp.setCrdUsedLvls(newUsedLvls);
2422  iterateOp.getBody()->eraseArguments(toRemove);
2423  rewriter.finalizeOpModification(iterateOp);
2424  return success();
2425  }
2426 };
2427 
2428 void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
2429  mlir::MLIRContext *context) {
2430  results.add<RemoveUnusedLvlCrds>(context);
2431 }
2432 
2433 void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2434  Value iterSpace, ValueRange initArgs) {
2435  unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
2436  // All ones.
2437  I64BitSet set((1 << rank) - 1);
2438  return build(builder, odsState, iterSpace, initArgs, set);
2439 }
2440 
2441 void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2442  Value iterSpace, ValueRange initArgs,
2443  I64BitSet crdUsedLvls) {
2444  OpBuilder::InsertionGuard guard(builder);
2445 
2446  odsState.addOperands(iterSpace);
2447  odsState.addOperands(initArgs);
2448  odsState.getOrAddProperties<Properties>().crdUsedLvls =
2449  builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
2450  Region *bodyRegion = odsState.addRegion();
2451  odsState.addTypes(initArgs.getTypes());
2452  Block *bodyBlock = builder.createBlock(bodyRegion);
2453 
2454  // Starts with a list of user-provided loop arguments.
2455  for (Value v : initArgs)
2456  bodyBlock->addArgument(v.getType(), v.getLoc());
2457 
2458  // Followed by a list of used coordinates.
2459  for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
2460  bodyBlock->addArgument(builder.getIndexType(), odsState.location);
2461 
2462  // Ends with the sparse iterator.
2463  bodyBlock->addArgument(
2464  llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
2465  odsState.location);
2466 }
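// Example (illustrative): for a 1-d iteration space with one loop-carried
// value of type f64, the body block built above has the argument layout
//   ^bb0(%carried: f64, %crd0: index, %it: !sparse_tensor.iterator<...>)
// i.e. loop-carried values first, then used coordinates, then the iterator.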
2467 
2468 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
2469  OpAsmParser::Argument iterator;
2470  OpAsmParser::UnresolvedOperand iterSpace;
2471 
2472  SmallVector<OpAsmParser::Argument> iters, iterArgs;
2473  if (parseSparseIterateLoop(parser, result, iters, iterArgs))
2474  return failure();
2475  if (iters.size() != 1)
2476  return parser.emitError(parser.getNameLoc(),
2477  "expected only one iterator/iteration space");
2478 
2479  iterArgs.append(iters);
2480  Region *body = result.addRegion();
2481  if (parser.parseRegion(*body, iterArgs))
2482  return failure();
2483 
2484  IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2485 
2486  // Parse the optional attribute list.
2487  if (parser.parseOptionalAttrDict(result.attributes))
2488  return failure();
2489 
2490  return success();
2491 }
2492 
2493 /// Prints the initialization list in the form of
2494 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
2495 /// where 'inner' values are assumed to be region arguments and 'outer' values
2496 /// are regular SSA values.
2497 static void printInitializationList(OpAsmPrinter &p,
2498  Block::BlockArgListType blocksArgs,
2499  ValueRange initializers,
2500  StringRef prefix = "") {
2501  assert(blocksArgs.size() == initializers.size() &&
2502  "expected same length of arguments and initializers");
2503  if (initializers.empty())
2504  return;
2505 
2506  p << prefix << '(';
2507  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
2508  p << std::get<0>(it) << " = " << std::get<1>(it);
2509  });
2510  p << ")";
2511 }
2512 
2513 template <typename SparseLoopOp>
2514 static LogicalResult verifySparseLoopOp(SparseLoopOp op) {
2515  if (op.getInitArgs().size() != op.getNumResults()) {
2516  return op.emitOpError(
2517  "mismatch in number of loop-carried values and defined values");
2518  }
2519  if (op.getCrdUsedLvls().max() > op.getSpaceDim())
2520  return op.emitOpError("requested out-of-bound coordinates");
2521 
2522  return success();
2523 }
2524 
2525 LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); }
2526 LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(*this); }
2527 
2528 void IterateOp::print(OpAsmPrinter &p) {
2529  p << " " << getIterator() << " in " << getIterSpace();
2530  if (!getCrdUsedLvls().empty()) {
2531  p << " at(";
2532  printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
2533  p << ")";
2534  }
2535  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
2536 
2537  p << " : " << getIterSpace().getType() << " ";
2538  if (!getInitArgs().empty())
2539  p.printArrowTypeList(getInitArgs().getTypes());
2540 
2541  p << " ";
2542  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2543  /*printBlockTerminators=*/!getInitArgs().empty());
2544 }
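// Example (illustrative): the printed form produced above reads roughly as
//   %r = sparse_tensor.iterate %it in %space at(%crd) iter_args(%a = %init)
//        : !sparse_tensor.iter_space<#CSR, lvls = 0> -> f64 { ... }
// with the at(...) and iter_args(...) clauses elided when unused.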
2545 
2546 LogicalResult IterateOp::verifyRegions() {
2547  if (getIterator().getType() != getIterSpace().getType().getIteratorType())
2548  return emitOpError("mismatch in iterator and iteration space type");
2549  if (getNumRegionIterArgs() != getNumResults())
2550  return emitOpError(
2551  "mismatch in number of basic block args and defined values");
2552 
2553  auto initArgs = getInitArgs();
2554  auto iterArgs = getRegionIterArgs();
2555  auto yieldVals = getYieldedValues();
2556  auto opResults = getResults();
2557  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2558  opResults.size()})) {
2559  return emitOpError() << "number mismatch between iter args and results.";
2560  }
2561 
2562  for (auto [i, init, iter, yield, ret] :
2563  llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2564  if (init.getType() != ret.getType())
2565  return emitOpError() << "types mismatch between " << i
2566  << "th iter operand and defined value";
2567  if (iter.getType() != ret.getType())
2568  return emitOpError() << "types mismatch between " << i
2569  << "th iter region arg and defined value";
2570  if (yield.getType() != ret.getType())
2571  return emitOpError() << "types mismatch between " << i
2572  << "th yield value and defined value";
2573  }
2574 
2575  return success();
2576 }
2577 
2578 /// OpInterfaces' methods implemented by IterateOp.
2579 SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
2580 
2581 MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
2582  return getInitArgsMutable();
2583 }
2584 
2585 Block::BlockArgListType IterateOp::getRegionIterArgs() {
2586  return getRegion().getArguments().take_front(getNumRegionIterArgs());
2587 }
2588 
2589 std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
2590  return cast<sparse_tensor::YieldOp>(
2591  getRegion().getBlocks().front().getTerminator())
2592  .getResultsMutable();
2593 }
2594 
2595 std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
2596 
2597 OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2598  return getInitArgs();
2599 }
2600 
2601 void IterateOp::getSuccessorRegions(RegionBranchPoint point,
2602  SmallVectorImpl<RegionSuccessor> &regions) {
2603  // Both the operation itself and the region may be branching into the body
2604  // or back into the operation itself.
2605  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2606  // It is possible for the loop not to enter the body.
2607  regions.push_back(RegionSuccessor(getResults()));
2608 }
2609 
2610 void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
2611  ValueRange iterSpaces, ValueRange initArgs,
2612  unsigned numCases) {
2613  unsigned rank =
2614  cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim();
2615  // All ones.
2616  I64BitSet set((1 << rank) - 1);
2617  // Generates all-zero case bits (they only serve as placeholders), which are
2618  // supposed to be overridden later. We need to preallocate all the regions, as
2619  // mlir::Region cannot be dynamically added later after the operation is
2620  // created.
2621  SmallVector<int64_t> caseBits(numCases, 0);
2622  ArrayAttr cases = builder.getI64ArrayAttr(caseBits);
2623  return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
2624  initArgs, set, cases,
2625  /*caseRegionsCount=*/numCases);
2626 }
2627 
2628 ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {
2629 
2630  SmallVector<Value> spaces;
2631  // The block argument list of each region is arranged in the order of
2632  // ([used coordinate list], [loop iteration args], [sparse iterator list]).
2633  SmallVector<OpAsmParser::Argument> blockArgs;
2634  if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs))
2635  return failure();
2636 
2637  result.addAttribute("operandSegmentSizes",
2638  parser.getBuilder().getDenseI32ArrayAttr(
2639  {static_cast<int32_t>(spaces.size()),
2640  static_cast<int32_t>(result.types.size())}));
2641 
2642  SmallVector<Attribute> cases;
2643  while (succeeded(parser.parseOptionalKeyword("case"))) {
2644  // Parse one region per case.
2645  I64BitSet definedItSet;
2646  SmallVector<OpAsmParser::Argument> definedIts;
2647  if (parseOptionalDefinedList(parser, result, definedItSet, definedIts,
2648  spaces.size(), OpAsmParser::Delimiter::None))
2649  return failure();
2650 
2651  cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet));
2652 
2653  for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) {
2654  // Resolve the iterator type based on the iteration space type.
2655  auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType());
2656  definedIts[i].type = spaceTp.getIteratorType();
2657  }
2658  definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
2659  Region *body = result.addRegion();
2660  if (parser.parseRegion(*body, definedIts))
2661  return failure();
2662 
2663  CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2664  }
2665 
2666  result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases));
2667 
2668  // Parse the optional attribute list.
2669  if (parser.parseOptionalAttrDict(result.attributes))
2670  return failure();
2671 
2672  return success();
2673 }
2674 
2675 void CoIterateOp::print(OpAsmPrinter &p) {
2676  p << " (";
2677  llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; });
2678  p << ")";
2679 
2680  if (!getCrdUsedLvls().empty()) {
2681  p << " at(";
2682  printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls());
2683  p << ")";
2684  }
2685 
2686  printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args");
2687 
2688  p << " : (" << getIterSpaces().getTypes() << ")";
2689  if (!getInitArgs().empty())
2690  p.printArrowTypeList(getInitArgs().getTypes());
2691 
2692  for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
2693  p.printNewline();
2694  p << "case ";
2695  printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx),
2696  getRegionDefinedSpace(idx));
2697  p << " ";
2698  p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false,
2699  /*printBlockTerminators=*/!getInitArgs().empty());
2700  }
2701 }
2702 
2703 ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
2704  return cast<sparse_tensor::YieldOp>(
2705  getRegion(regionIdx).getBlocks().front().getTerminator())
2706  .getResults();
2707 }
2708 
2709 LogicalResult CoIterateOp::verifyRegions() {
2710  for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
2711  if (getNumRegionIterArgs() != getNumResults())
2712  return emitOpError(
2713  "mismatch in number of basic block args and defined values");
2714 
2715  auto initArgs = getInitArgs();
2716  auto iterArgs = getRegionIterArgs(r);
2717  auto yieldVals = getYieldedValues(r);
2718  auto opResults = getResults();
2719  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2720  opResults.size()})) {
2721  return emitOpError()
2722  << "number mismatch between iter args and results on " << r
2723  << "th region";
2724  }
2725 
2726  for (auto [i, init, iter, yield, ret] :
2727  llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2728  if (init.getType() != ret.getType())
2729  return emitOpError()
2730  << "types mismatch between " << i
2731  << "th iter operand and defined value on " << r << "th region";
2732  if (iter.getType() != ret.getType())
2733  return emitOpError() << "types mismatch between " << i
2734  << "th iter region arg and defined value on " << r
2735  << "th region";
2736  if (yield.getType() != ret.getType())
2737  return emitOpError()
2738  << "types mismatch between " << i
2739  << "th yield value and defined value on " << r << "th region";
2740  }
2741  }
2742 
2743  auto cases = getRegionDefinedSpaces();
2744  llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
2745  if (set.size() != getNumRegions())
2746  return emitOpError("contains duplicated cases.");
2747 
2748  return success();
2749 }
2750 
2751 SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
2752  SmallVector<Region *> ret;
2753  I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
2754  for (Region &r : getCaseRegions())
2755  if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
2756  ret.push_back(&r);
2757 
2758  return ret;
2759 }
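// Example (illustrative): with case regions defined over two iteration
// spaces as
//   case %it0, %it1  (bits 0b11)
//   case %it0, _     (bits 0b01)
//   case _, %it1     (bits 0b10)
// the sub-cases of the first region are all three regions, since 0b01 and
// 0b10 are both subsets of 0b11.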
2760 
2761 //===----------------------------------------------------------------------===//
2762 // Sparse Tensor Dialect Setups.
2763 //===----------------------------------------------------------------------===//
2764 
2765 /// Materialize a single constant operation from a given attribute value with
2766 /// the desired resultant type.
2767 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2768  Attribute value, Type type,
2769  Location loc) {
2770  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2771  return op;
2772  return nullptr;
2773 }
2774 
2775 namespace {
2776 struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
2777  using OpAsmDialectInterface::OpAsmDialectInterface;
2778 
2779  AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
2780  if (isa<SparseTensorEncodingAttr>(attr)) {
2781  os << "sparse";
2782  return AliasResult::OverridableAlias;
2783  }
2784  return AliasResult::NoAlias;
2785  }
2786 };
2787 } // namespace
2788 
2789 void SparseTensorDialect::initialize() {
2790  addInterface<SparseTensorAsmDialectInterface>();
2791  addAttributes<
2792 #define GET_ATTRDEF_LIST
2793 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
2794  >();
2795  addTypes<
2796 #define GET_TYPEDEF_LIST
2797 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
2798  >();
2799  addOperations<
2800 #define GET_OP_LIST
2801 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2802  >();
2803  declarePromisedInterfaces<
2804  bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
2805  NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
2806  ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
2807 }
2808 
2809 #define GET_OP_CLASSES
2810 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2811 
2812 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static bool isPermutation(std::vector< PermutationTy > permutation)
Definition: IRAffine.cpp:71
static MLIRContext * getContext(OpFoldResult val)
bool isUnique(It begin, It end)
Definition: MeshOps.cpp:140
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
static LogicalResult verifyNumBlockArgs(T *op, Region &region, const char *regionName, TypeRange inputTypes, Type outputType)
static ParseResult parseOptionalStaticSlice(int64_t &result, AsmParser &parser)
static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc)
We normalized sparse tensor encoding attribute by always using ordered/unique LT such that "compresse...
static ParseResult parseUsedCoordList(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &coords)
static LogicalResult isMatchingWidth(Value mem, unsigned width)
static constexpr bool acceptBitWidth(unsigned bitWidth)
static mlir::ParseResult parseLevelRange(mlir::AsmParser &, mlir::sparse_tensor::Level &, mlir::sparse_tensor::Level &)
Parses a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static LogicalResult lvlIsInBounds(Level lvl, Value tensor)
static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size, Block::BlockArgListType blocksArgs, I64BitSet definedSet)
static constexpr FieldIndex kDataFieldStartingIdx
static constexpr Level kInvalidLevel
static LogicalResult verifySparseLoopOp(SparseLoopOp op)
static constexpr Level kInvalidFieldIndex
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, mlir::sparse_tensor::Level)
Prints a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind)
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op)
static SmallVector< Size > getSparseFieldShape(const SparseTensorEncodingAttr enc, std::optional< ArrayRef< int64_t >> dimShape)
static ParseResult parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &iterators, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static ParseResult parseOptionalDefinedList(OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, SmallVectorImpl< OpAsmParser::Argument > &definedArgs, unsigned maxCnt=std::numeric_limits< unsigned >::max(), OpAsmParser::Delimiter delimiter=OpAsmParser::Delimiter::Paren)
Parses a list of optional defined list in the form of "(%val0, _, %val1, ...)", where _ is used to an...
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, SparseTensorType stt, RankedTensorType valTp, TypeRange lvlTps)
static ParseResult parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< Value > &spacesVals, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static LogicalResult verifySparsifierGetterSetter(StorageSpecifierKind mdKind, std::optional< Level > lvl, TypedValue< StorageSpecifierType > md, Operation *op)
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, OpaqueProperties prop, RegionRange region, SmallVectorImpl< mlir::Type > &ret)
static bool isAllDense(uint64_t lvlRank, const LevelType *lvlTypes)
Definition: Storage.cpp:20
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
Base type for affine expression.
Definition: AffineExpr.h:68
void print(raw_ostream &os) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:415
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
Definition: AffineMap.cpp:369
unsigned getNumSymbols() const
Definition: AffineMap.cpp:398
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:625
The possible results of an alias query.
Definition: AliasAnalysis.h:26
@ NoAlias
The two locations do not alias at all.
Definition: AliasAnalysis.h:34
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseLBrace()=0
Parse a { token.
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseQuestion()=0
Parse a '?' token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
Definition: AsmPrinter.cpp:77
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printArrowTypeList(TypeRange &&types)
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:152
BlockArgListType getArguments()
Definition: Block.h:85
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:191
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:250
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:140
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:99
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:300
IndexType getIndexType()
Definition: Builders.cpp:83
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:313
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
OpAsmDialectInterface(Dialect *dialect)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:353
This class helps build Operations.
Definition: Builders.h:212
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:449
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:476
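A minimal sketch of OpBuilder together with the RAII insertion guard above; the region and location are assumed to come from the caller:
static void buildEmptyBody(mlir::OpBuilder &builder, mlir::Region &region,
                           mlir::Location loc) {
  mlir::OpBuilder::InsertionGuard guard(builder); // restores the insertion point
  mlir::Block *entry = builder.createBlock(&region); // insertion point now in new block
  (void)entry;
  // builder.create<SomeOp>(loc, ...) would now insert into the new block.
  (void)loc;
}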
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Simple wrapper around a void* in order to express generically how to pass in op properties through APIs.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
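A hypothetical verifier-style helper combining the Operation accessors and diagnostics above:
static mlir::LogicalResult checkSingleIndexResult(mlir::Operation *op) {
  if (op->getNumResults() != 1)
    return op->emitOpError("expected exactly one result");
  if (!op->getResult(0).getType().isIndex())
    return op->emitError("expected an index result");
  return mlir::success();
}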
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.
Definition: PatternMatch.h:785
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen.
Definition: PatternMatch.h:614
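A sketch of the in-place modification protocol; in practice most code calls RewriterBase::modifyOpInPlace, which wraps this begin/end pair (rewriter and op are assumed to be in scope):
rewriter.startOpModification(op);
op->setAttr("example.visited", rewriter.getUnitAttr()); // hypothetical attribute
rewriter.finalizeOpModification(op);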
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:59
bool isInteger() const
Return true if this is an integer type (of any width).
Definition: Types.cpp:61
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
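A small sketch combining the Value and Type queries above:
static void inspect(mlir::Value v) {
  mlir::Type t = v.getType();
  if (t.isIndex() || t.isInteger())
    if (mlir::Operation *def = v.getDefiningOp())
      def->emitRemark("integer-like value produced here");
}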
A simple wrapper to encode a bitset of (at most 64) levels, currently used by the sparse_tensor.iterate operation.
Definition: SparseTensor.h:64
iterator_range< const_set_bits_iterator > bits() const
Definition: SparseTensor.h:75
I64BitSet & set(unsigned i)
Definition: SparseTensor.h:83
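A minimal use of the I64BitSet wrapper above (assuming a default-constructed set is empty):
mlir::sparse_tensor::I64BitSet lvls;
lvls.set(0).set(2); // mark levels 0 and 2
for (unsigned l : lvls.bits())
  llvm::outs() << "level " << l << "\n"; // prints levels 0 and 2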
A wrapper around RankedTensorType, which has three goals:
unsigned getCrdWidth() const
Returns the coordinate-overhead bitwidth, defaulting to zero.
SmallVector< Size > getBatchLvlShape() const
Returns the batched level-shape.
ArrayRef< LevelType > getLvlTypes() const
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
ArrayRef< Size > getDimShape() const
Returns the dimension-shape.
bool isAllOrdered() const
Returns true for tensors where every level is ordered.
SmallVector< Size > getLvlShape() const
Returns the level-shape.
bool isCOOType(Level startLvl=0, bool isUnique=true) const
Returns true iff this sparse tensor type has a trailing COO region starting at the given level.
Dimension getDimRank() const
Returns the dimension-rank.
bool isAllDense() const
Returns true for tensors where every level is dense.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
bool hasSameDimToLvl(const SparseTensorType &other) const
Returns true iff the two types have the same mapping.
bool hasStaticDimShape() const
Returns true if no dimension has dynamic size.
Level getLvlRank() const
Returns the level-rank.
unsigned getPosWidth() const
Returns the position-overhead bitwidth, defaulting to zero.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least two levels.
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
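A sketch of interrogating a SparseTensorType with the queries above (the wrapper is assumed to have been obtained from a value carrying a sparse encoding):
static void summarize(const mlir::sparse_tensor::SparseTensorType &stt) {
  if (!stt.hasEncoding())
    return; // a dense tensor: no sparse storage to describe
  llvm::outs() << "dim-rank " << stt.getDimRank()
               << ", lvl-rank " << stt.getLvlRank()
               << ", identity map " << stt.isIdentity()
               << ", all ordered " << stt.isAllOrdered() << "\n";
  mlir::Type posTp = stt.getPosType(); // IndexType unless a posWidth was set
  (void)posTp;
}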
Provides methods to access fields of a sparse tensor with the given encoding.
unsigned getNumDataFields() const
Gets the total number of data fields (coordinate arrays, position arrays, and a value array) for the given sparse tensor encoding.
unsigned getNumFields() const
Gets the total number of fields for the given sparse tensor encoding.
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with the corresponding field index, field kind, level, and level-type.
std::pair< FieldIndex, unsigned > getFieldIndexAndStride(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Parses the Sparse Tensor Encoding Attribute (STEA).
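A sketch of walking the storage fields via StorageLayout::foreachField; for CSR this visits a positions array, a coordinates array, the values array, and the storage specifier (enc is assumed to be a SparseTensorEncodingAttr in scope):
mlir::sparse_tensor::StorageLayout layout(enc);
layout.foreachField([](mlir::sparse_tensor::FieldIndex fidx,
                       mlir::sparse_tensor::SparseTensorFieldKind kind,
                       mlir::sparse_tensor::Level lvl,
                       mlir::sparse_tensor::LevelType lt) {
  // Inspect (fidx, kind, lvl, lt) here; returning true keeps iterating.
  return true;
});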
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators, which is equal to the determinant of the matrix formed by its generators.
Definition: Barvinok.cpp:63
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:334
bool isWithCrdLT(LevelType lt)
Definition: Enums.h:431
bool isWithPosLT(LevelType lt)
Definition: Enums.h:432
bool isOrderedLT(LevelType lt)
Definition: Enums.h:425
std::string toMLIRString(LevelType lt)
Definition: Enums.h:447
Dimension toDim(SparseTensorEncodingAttr enc, Level l)
Convenience method to translate the given level to the corresponding dimension.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
bool isSingletonLT(LevelType lt)
Definition: Enums.h:421
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
Definition: SparseTensor.h:39
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:42
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
uint64_t getN(LevelType lt)
Definition: Enums.h:442
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic (for shapes).
Definition: SparseTensor.h:46
llvm::hash_code hash_value(LevelType lt)
RankedTensorType getRankedTensorType(T &&t)
Convenience method to abbreviate casting getType().
Definition: SparseTensor.h:160
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context)
Given the dimToLvl map, infers the lvlToDim map, or returns an empty AffineMap when inference fails.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Definition: SparseTensor.h:168
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
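A sketch tying the encoding lookup and the dim/lvl translation helpers together (tensorType is assumed to be a tensor type in scope):
auto enc = mlir::sparse_tensor::getSparseTensorEncoding(tensorType);
if (enc && enc.isIdentity()) {
  // With an identity dimToLvl map, dimensions and levels coincide.
  assert(mlir::sparse_tensor::toDim(enc, 0) == 0);
  assert(mlir::sparse_tensor::toLvl(enc, 0) == 0);
}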
bool isBlockSparsity(AffineMap dimToLvl)
Given the dimToLvl map, returns whether it encodes block sparsity.
bool isDenseLT(LevelType lt)
Definition: Enums.h:413
uint64_t getM(LevelType lt)
Definition: Enums.h:443
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff the MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SparseTensorFieldKind
The kinds of fields in the sparse tensor storage scheme.
bool isBatchLT(LevelType lt)
Definition: Enums.h:414
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context)
Returns the lvlToDim map for the given dimToLvl map specific to the block sparse cases.
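For a 2x3 block-sparse map (i, j) -> (i floordiv 2, j floordiv 3, i mod 2, j mod 3), the block helpers above behave as sketched here (dimToLvl and context assumed in scope):
if (mlir::sparse_tensor::isBlockSparsity(dimToLvl)) {
  llvm::SmallVector<unsigned> sizes =
      mlir::sparse_tensor::getBlockSize(dimToLvl); // {2, 3} for the map above
  mlir::AffineMap lvlToDim =
      mlir::sparse_tensor::inverseBlockSparsity(dimToLvl, context);
  (void)sizes;
  (void)lvlToDim;
}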
std::optional< LevelType > buildLevelType(LevelFormat lf, const std::vector< LevelPropNonDefault > &properties, uint64_t n=0, uint64_t m=0)
Definition: Enums.h:402
bool isNOutOfMLT(LevelType lt)
Definition: Enums.h:424
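A sketch of building a level type and probing it with the predicates above:
auto lt = mlir::sparse_tensor::buildLevelType(
    mlir::sparse_tensor::LevelFormat::Compressed, /*properties=*/{});
if (lt) {
  assert(mlir::sparse_tensor::isWithPosLT(*lt)); // compressed stores positions
  assert(mlir::sparse_tensor::isWithCrdLT(*lt)); // ... and coordinates
  assert(mlir::sparse_tensor::isOrderedLT(*lt)); // ordered by default
  llvm::outs() << mlir::sparse_tensor::toMLIRString(*lt) << "\n"; // "compressed"
}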
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
Definition: AffineMap.cpp:768
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:70
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:631
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:607
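A sketch composing the AffineExpr helpers above: build d0 * 2 + 4 and simplify it (ctx is assumed to be an MLIRContext pointer in scope):
mlir::AffineExpr d0 = mlir::getAffineDimExpr(0, ctx);
mlir::AffineExpr two = mlir::getAffineConstantExpr(2, ctx);
mlir::AffineExpr expr =
    mlir::getAffineBinaryOpExpr(mlir::AffineExprKind::Mul, d0, two) + 4;
expr = mlir::simplifyAffineExpr(expr, /*numDims=*/1, /*numSymbols=*/0);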
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:426
LogicalResult matchAndRewrite(IterateOp iterateOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
Definition: PatternMatch.h:362
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.
A simple structure that encodes a range of levels in a sparse tensor that form a COO segment.
Definition: SparseTensor.h:50
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.
Definition: Enums.h:326
LevelType stripStorageIrrelevantProperties() const
Definition: Enums.h:299