//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "Detail/DimLvlMapParser.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/Bitset.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"

// Forward declarations: the following custom print/parse methods are
// referenced by the generated code for SparseTensorTypes.td.
static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
                                         mlir::sparse_tensor::Level &,
                                         mlir::sparse_tensor::Level &);
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
                            mlir::sparse_tensor::Level);

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed
// as well.
namespace mlir::sparse_tensor {
llvm::hash_code hash_value(LevelType lt) {
  return llvm::hash_value(static_cast<uint64_t>(lt));
}
} // namespace mlir::sparse_tensor

//===----------------------------------------------------------------------===//
// Local Convenience Methods.
//===----------------------------------------------------------------------===//

static constexpr bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}

static SmallVector<Size>
getSparseFieldShape(const SparseTensorEncodingAttr enc,
                    std::optional<ArrayRef<int64_t>> dimShape) {
  assert(enc);
  // With only the encoding, we cannot determine the static shape for leading
  // batch levels; we therefore return a dynamically shaped memref instead.
  SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
  if (dimShape.has_value()) {
    // If the actual tensor shape is provided, we can then refine the leading
    // batch dimensions.
    SmallVector<int64_t> lvlShape =
        enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
    memrefShape.assign(lvlShape.begin(),
                       lvlShape.begin() + enc.getBatchLvlRank());
  }
  // Another dynamic dimension to store the sparse level.
  memrefShape.push_back(ShapedType::kDynamic);
  return memrefShape;
}
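
// As an illustration: for a batched CSR encoding with lvlTypes =
// [batch, dense, compressed] and a static dimShape of 4x8x16,
// getBatchLvlRank() == 1, so the leading batch dimension is refined to 4 and
// the returned field shape is [4, ?]; without a dimShape it would be [?, ?].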

//===----------------------------------------------------------------------===//
// SparseTensorDialect StorageLayout.
//===----------------------------------------------------------------------===//

static constexpr Level kInvalidLevel = -1u;
static constexpr Level kInvalidFieldIndex = -1u;
static constexpr FieldIndex kDataFieldStartingIdx = 0;
101 
104  LevelType)>
105  callback) const {
106  const auto lvlTypes = enc.getLvlTypes();
107  const Level lvlRank = enc.getLvlRank();
108  SmallVector<COOSegment> cooSegs = enc.getCOOSegments();
110 
111  ArrayRef cooSegsRef = cooSegs;
112  // Per-level storage.
113  for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
114  const auto lt = lvlTypes[l];
115  if (isWithPosLT(lt)) {
116  if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
117  return;
118  }
119  if (isWithCrdLT(lt)) {
120  if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
121  return;
122  }
123  if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
124  if (!cooSegsRef.front().isSoA) {
125  // AoS COO, all singletons are fused into one memrefs. Skips the entire
126  // COO segement.
127  l = cooSegsRef.front().lvlRange.second;
128  } else {
129  // SoA COO, each singleton level has one memref.
130  l++;
131  }
132  // Expire handled COO segment.
133  cooSegsRef = cooSegsRef.drop_front();
134  } else {
135  // Non COO levels.
136  l++;
137  }
138  }
139  // The values array.
140  if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
142  return;
143  // Put metadata at the end.
144  if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
146  return;
147 }
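
// For example, for a 2-d CSR encoding (lvlTypes = [dense, compressed]) the
// callback is invoked in this order: (0, PosMemRef, 1), (1, CrdMemRef, 1),
// (2, ValMemRef, kInvalidLevel), and finally (3, StorageSpec, kInvalidLevel);
// the dense level 0 contributes no memref fields.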

void sparse_tensor::foreachFieldAndTypeInSparseTensor(
    SparseTensorType stt,
    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) {
  assert(stt.hasEncoding());

  SmallVector<int64_t> memrefShape =
      getSparseFieldShape(stt.getEncoding(), stt.getDimShape());

  const Type specType = StorageSpecifierType::get(stt.getEncoding());
  // memref<[batch] x ? x pos> positions
  const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
  // memref<[batch] x ? x crd> coordinates
  const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
  // memref<[batch] x ? x eltType> values
  const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());

  StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
                                   callback](FieldIndex fieldIdx,
                                             SparseTensorFieldKind fieldKind,
                                             Level lvl, LevelType lt) -> bool {
    switch (fieldKind) {
    case SparseTensorFieldKind::StorageSpec:
      return callback(specType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::PosMemRef:
      return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::CrdMemRef:
      return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::ValMemRef:
      return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
    };
    llvm_unreachable("unrecognized field kind");
  });
}

unsigned StorageLayout::getNumFields() const {
  unsigned numFields = 0;
  foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
    numFields++;
    return true;
  });
  return numFields;
}

unsigned StorageLayout::getNumDataFields() const {
  unsigned numFields = 0; // one value memref
  foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
    if (fidx >= kDataFieldStartingIdx)
      numFields++;
    return true;
  });
  numFields -= 1; // the last field is the StorageSpecifier
  assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
  return numFields;
}

std::pair<FieldIndex, unsigned>
StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
                                      std::optional<Level> lvl) const {
  FieldIndex fieldIdx = kInvalidFieldIndex;
  unsigned stride = 1;
  if (kind == SparseTensorFieldKind::CrdMemRef) {
    assert(lvl.has_value());
    const Level cooStart = enc.getAoSCOOStart();
    const Level lvlRank = enc.getLvlRank();
    if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
      lvl = cooStart;
      stride = lvlRank - cooStart;
    }
  }
  foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
                                      SparseTensorFieldKind fKind, Level fLvl,
                                      LevelType lt) -> bool {
    if ((lvl && fLvl == lvl.value() && kind == fKind) ||
        (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
      fieldIdx = fIdx;
      // Return false to break the iteration.
      return false;
    }
    return true;
  });
  assert(fieldIdx != kInvalidFieldIndex);
  return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
}
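
// For example, for a 3-d AoS COO encoding (compressed(nonunique), singleton,
// singleton), all levels of the segment share the coordinate memref of the
// segment start, so querying CrdMemRef at level 1 or 2 yields the field index
// of level 0 with stride 3 (= lvlRank - cooStart).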

//===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//

std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
  return isDynamic(v) ? std::nullopt
                      : std::make_optional(static_cast<uint64_t>(v));
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
  return getStatic(getOffset());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
  return getStatic(getStride());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
  return getStatic(getSize());
}

bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
  return isDynamic(getOffset()) && isDynamic(getStride()) &&
         isDynamic(getSize());
}

std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
  return isDynamic(v) ? "?" : std::to_string(v);
}

void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
  assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
  os << '(';
  os << getStaticString(getOffset());
  os << ", ";
  os << getStaticString(getSize());
  os << ", ";
  os << getStaticString(getStride());
  os << ')';
}
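
// For example, a slice with offset 1, size 4, and stride 2 prints as
// "(1, 4, 2)", while a fully dynamic slice prints as "(?, ?, ?)".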

void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
  print(printer.getStream());
}

static ParseResult parseOptionalStaticSlice(int64_t &result,
                                            AsmParser &parser) {
  auto parseResult = parser.parseOptionalInteger(result);
  if (parseResult.has_value()) {
    if (parseResult.value().succeeded() && result < 0) {
      parser.emitError(
          parser.getCurrentLocation(),
          "expect positive value or ? for slice offset/size/stride");
      return failure();
    }
    return parseResult.value();
  }

  // Otherwise, expect a '?' representing a dynamic slice.
  result = SparseTensorDimSliceAttr::kDynamic;
  return parser.parseQuestion();
}

Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
  int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;

  if (failed(parser.parseLParen()) ||
      failed(parseOptionalStaticSlice(offset, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(size, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(stride, parser)) ||
      failed(parser.parseRParen()))
    return {};

  return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
                                                     offset, size, stride);
}

LogicalResult
SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 int64_t offset, int64_t size, int64_t stride) {
  if (!isDynamic(offset) && offset < 0)
    return emitError() << "expect non-negative value or ? for slice offset";
  if (!isDynamic(size) && size <= 0)
    return emitError() << "expect positive value or ? for slice size";
  if (!isDynamic(stride) && stride <= 0)
    return emitError() << "expect positive value or ? for slice stride";
  return success();
}
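
// A dimension slice appears inline in the dimension list of an encoding's
// `map`; a sketch of the accepted syntax:
//
//   #CSR_SLICE = #sparse_tensor.encoding<{
//     map = (d0 : #sparse_tensor<slice(1, 4, 1)>,
//            d1 : #sparse_tensor<slice(1, 4, 2)>)
//           -> (d0 : dense, d1 : compressed)
//   }>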

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
  return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
  return withDimToLvl(AffineMap());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                        unsigned crdWidth) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
      crdWidth, getExplicitVal(), getImplicitVal());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
  return withBitWidths(0, 0);
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), explicitVal, getImplicitVal());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
  return withExplicitVal(Attribute());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), implicitVal);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
  return withImplicitVal(Attribute());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
  return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
}

uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
  ArrayRef<LevelType> lvlTypes = getLvlTypes();
  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  return std::distance(lastBatch, lvlTypes.rend());
}

bool SparseTensorEncodingAttr::isAllDense() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
}

bool SparseTensorEncodingAttr::isAllOrdered() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
}

Type SparseTensorEncodingAttr::getCrdElemType() const {
  if (!getImpl())
    return nullptr;
  if (getCrdWidth())
    return IntegerType::get(getContext(), getCrdWidth());
  return IndexType::get(getContext());
}

Type SparseTensorEncodingAttr::getPosElemType() const {
  if (!getImpl())
    return nullptr;
  if (getPosWidth())
    return IntegerType::get(getContext(), getPosWidth());
  return IndexType::get(getContext());
}

MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
  return MemRefType::get(shape, getCrdElemType());
}

MemRefType SparseTensorEncodingAttr::getPosMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
  return MemRefType::get(shape, getPosElemType());
}

bool SparseTensorEncodingAttr::isIdentity() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
}

bool SparseTensorEncodingAttr::isPermutation() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
}

Dimension SparseTensorEncodingAttr::getDimRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  const auto dimToLvl = getDimToLvl();
  return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
}

Level SparseTensorEncodingAttr::getLvlRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return getLvlTypes().size();
}

LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
  if (!getImpl())
    return LevelFormat::Batch;
  assert(l < getLvlRank() && "Level is out of bounds");
  return getLvlTypes()[l];
}

bool SparseTensorEncodingAttr::isSlice() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return !getDimSlices().empty();
}

SparseTensorDimSliceAttr
SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
  assert(isSlice() && "Is not a slice");
  const auto dimSlices = getDimSlices();
  assert(dim < dimSlices.size() && "Dimension is out of bounds");
  return dimSlices[dim];
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
  return getDimSlice(dim).getStaticOffset();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
  return getDimSlice(dim).getStaticStride();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
  return getStaticDimSliceOffset(toDim(*this, lvl));
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
  return getStaticDimSliceStride(toDim(*this, lvl));
}

SmallVector<int64_t>
SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
                                         CrdTransDirectionKind dir) const {
  if (isIdentity())
    return SmallVector<int64_t>(srcShape);

  SmallVector<int64_t> ret;
  unsigned rank =
      dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
  ret.reserve(rank);

  if (isPermutation()) {
    for (unsigned r = 0; r < rank; r++) {
      unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
                                                             : toLvl(*this, r);
      ret.push_back(srcShape[trans]);
    }
    return ret;
  }

  // Handle non-permutation maps.
  AffineMap transMap =
      dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();

  SmallVector<AffineExpr> dimRep;
  dimRep.reserve(srcShape.size());
  for (int64_t sz : srcShape) {
    if (!ShapedType::isDynamic(sz)) {
      // Push back the max coordinate for the given dimension/level size.
      dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
    } else {
      // A dynamic size; use an AffineDimExpr to symbolize the value.
      dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
    }
  }

  for (AffineExpr exp : transMap.getResults()) {
    // Do constant propagation on the affine map.
    AffineExpr evalExp =
        simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
    // Use the llvm namespace here to avoid ambiguity.
    if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
      ret.push_back(c.getValue() + 1);
    } else {
      if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
          mod && mod.getKind() == AffineExprKind::Mod) {
        // We can still infer a static bound for expressions of the form
        // "d % constant", since d % constant lies in [0, constant).
        if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
          ret.push_back(bound.getValue());
          continue;
        }
      }
      ret.push_back(ShapedType::kDynamic);
    }
  }
  assert(ret.size() == rank);
  return ret;
}
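
// For example, with the 2x3 block-sparse map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3),
// translating the static dimension shape [4, 6] in the dim2lvl direction
// propagates the max coordinates through the map (e.g. 3 floordiv 2 = 1,
// 5 mod 3 = 2) and adds one, yielding the level shape [2, 2, 2, 3]; for a
// dynamic dimension size, a "d % c" result still gets the static bound c.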

ValueRange
SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
                                        ValueRange crds,
                                        CrdTransDirectionKind dir) const {
  if (!getImpl())
    return crds;

  SmallVector<Type> retType(
      dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
      builder.getIndexType());
  auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
  return transOp.getOutCrds();
}

Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  // Open "<{" part.
  if (failed(parser.parseLess()))
    return {};
  if (failed(parser.parseLBrace()))
    return {};

  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<LevelType> lvlTypes;
  SmallVector<SparseTensorDimSliceAttr> dimSlices;
  AffineMap dimToLvl = {};
  AffineMap lvlToDim = {};
  unsigned posWidth = 0;
  unsigned crdWidth = 0;
  Attribute explicitVal;
  Attribute implicitVal;
  StringRef attrName;
  SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
                                    "explicitVal", "implicitVal"};
  while (succeeded(parser.parseOptionalKeyword(&attrName))) {
    // Detect an admissible keyword.
    auto *it = find(keys, attrName);
    if (it == keys.end()) {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
      return {};
    }
    unsigned keyWordIndex = it - keys.begin();
    // Consume the `=` after the key.
    if (failed(parser.parseEqual()))
      return {};
    // Dispatch on the keyword.
    switch (keyWordIndex) {
    case 0: { // map
      ir_detail::DimLvlMapParser cParser(parser);
      auto res = cParser.parseDimLvlMap();
      if (failed(res))
        return {};
      const auto &dlm = *res;

      const Level lvlRank = dlm.getLvlRank();
      for (Level lvl = 0; lvl < lvlRank; lvl++)
        lvlTypes.push_back(dlm.getLvlType(lvl));

      const Dimension dimRank = dlm.getDimRank();
      for (Dimension dim = 0; dim < dimRank; dim++)
        dimSlices.push_back(dlm.getDimSlice(dim));
      // NOTE: the old syntax requires an all-or-nothing approach to
      // `dimSlices`; therefore, if any slice actually exists then we need
      // to convert null-DSA into default/nop DSA.
      const auto isDefined = [](SparseTensorDimSliceAttr slice) {
        return static_cast<bool>(slice.getImpl());
      };
      if (llvm::any_of(dimSlices, isDefined)) {
        const auto defaultSlice =
            SparseTensorDimSliceAttr::get(parser.getContext());
        for (Dimension dim = 0; dim < dimRank; dim++)
          if (!isDefined(dimSlices[dim]))
            dimSlices[dim] = defaultSlice;
      } else {
        dimSlices.clear();
      }

      dimToLvl = dlm.getDimToLvlMap(parser.getContext());
      lvlToDim = dlm.getLvlToDimMap(parser.getContext());
      break;
    }
    case 1: { // posWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral position bitwidth");
        return {};
      }
      posWidth = intAttr.getInt();
      break;
    }
    case 2: { // crdWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      crdWidth = intAttr.getInt();
      break;
    }
    case 3: { // explicitVal
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
        explicitVal = result;
      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
        explicitVal = result;
      } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
        explicitVal = result;
      } else {
        parser.emitError(parser.getNameLoc(),
                         "expected a numeric value for explicitVal");
        return {};
      }
      break;
    }
    case 4: { // implicitVal
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
        implicitVal = result;
      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
        implicitVal = result;
      } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
        implicitVal = result;
      } else {
        parser.emitError(parser.getNameLoc(),
                         "expected a numeric value for implicitVal");
        return {};
      }
      break;
    }
    } // switch
    // Only the last item can omit the comma.
    if (parser.parseOptionalComma().failed())
      break;
  }

  // Close "}>" part.
  if (failed(parser.parseRBrace()))
    return {};
  if (failed(parser.parseGreater()))
    return {};

  // Construct struct-like storage for the attribute.
  if (!lvlToDim || lvlToDim.isEmpty()) {
    lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
  }
  return parser.getChecked<SparseTensorEncodingAttr>(
      parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
      explicitVal, implicitVal, dimSlices);
}
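
// For example, the parser above accepts an encoding such as:
//
//   #CSR = #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed),
//     posWidth = 32,
//     crdWidth = 32
//   }>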

void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
  auto map = static_cast<AffineMap>(getDimToLvl());
  // An empty affine map indicates the identity map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
  printer << "<{ map = ";
  printSymbols(map, printer);
  printer << '(';
  printDimensions(map, printer, getDimSlices());
  printer << ") -> (";
  printLevels(map, printer, getLvlTypes());
  printer << ')';
  // Print the remaining members only for non-default values.
  if (getPosWidth())
    printer << ", posWidth = " << getPosWidth();
  if (getCrdWidth())
    printer << ", crdWidth = " << getCrdWidth();
  if (getExplicitVal()) {
    printer << ", explicitVal = " << getExplicitVal();
  }
  if (getImplicitVal())
    printer << ", implicitVal = " << getImplicitVal();
  printer << " }>";
}

void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
                                            AsmPrinter &printer) const {
  if (map.getNumSymbols() == 0)
    return;
  printer << '[';
  for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
    printer << 's' << i << ", ";
  if (map.getNumSymbols() >= 1)
    printer << 's' << map.getNumSymbols() - 1;
  printer << ']';
}

void SparseTensorEncodingAttr::printDimensions(
    AffineMap &map, AsmPrinter &printer,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  if (!dimSlices.empty()) {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << " : " << dimSlices[i] << ", ";
    if (map.getNumDims() >= 1) {
      printer << 'd' << map.getNumDims() - 1 << " : "
              << dimSlices[map.getNumDims() - 1];
    }
  } else {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << ", ";
    if (map.getNumDims() >= 1)
      printer << 'd' << map.getNumDims() - 1;
  }
}

void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
                                           ArrayRef<LevelType> lvlTypes) const {
  for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
    map.getResult(i).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
  }
  if (map.getNumResults() >= 1) {
    auto lastIndex = map.getNumResults() - 1;
    map.getResult(lastIndex).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[lastIndex]);
  }
}

LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
    AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
    unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
  if (!acceptBitWidth(posWidth))
    return emitError() << "unexpected position bitwidth: " << posWidth;
  if (!acceptBitWidth(crdWidth))
    return emitError() << "unexpected coordinate bitwidth: " << crdWidth;

  // Verify every COO segment.
  auto *it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
  while (it != lvlTypes.end()) {
    if (it == lvlTypes.begin() ||
        !(isCompressedLT(*(it - 1)) || isLooseCompressedLT(*(it - 1))))
      return emitError() << "expected compressed or loose_compressed level "
                            "before singleton level";

    auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT);
    if (!std::all_of(it, curCOOEnd,
                     [](LevelType i) { return isSingletonLT(i); }))
      return emitError() << "expected all singleton lvlTypes "
                            "following a singleton level";
    // We can potentially support mixed SoA/AoS singleton levels.
    if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
          return it->isa<LevelPropNonDefault::SoA>() ==
                 i.isa<LevelPropNonDefault::SoA>();
        })) {
      return emitError() << "expected all singleton lvlTypes stored in the "
                            "same memory layout (SoA vs AoS).";
    }
    it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
  }

  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
    return emitError() << "Batch lvlTypes can only be leading levels.";

  // The SoA property can only be applied to singleton levels.
  auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
    return lt.isa<LevelPropNonDefault::SoA>();
  });
  if (llvm::any_of(soaLvls, [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      })) {
    return emitError() << "SoA is only applicable to singleton lvlTypes.";
  }

  // TODO: audit formats that actually are supported by backend.
  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
      it != std::end(lvlTypes)) {
    if (it != lvlTypes.end() - 1)
      return emitError() << "expected n_out_of_m to be the last level type";
    if (!std::all_of(lvlTypes.begin(), it,
                     [](LevelType i) { return isDenseLT(i); }))
      return emitError() << "expected all dense lvlTypes "
                            "before a n_out_of_m level";
    if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
      if (!isBlockSparsity(dimToLvl)) {
        return emitError()
               << "expected 1xm block structure for n_out_of_m level";
      }
      auto sizes = getBlockSize(dimToLvl);
      unsigned coefficient = 0;
      for (const auto &elem : sizes) {
        if (elem != 0) {
          if (elem != coefficient && coefficient != 0) {
            return emitError() << "expected only one blocked level "
                                  "with the same coefficients";
          }
          coefficient = elem;
        }
      }
      if (coefficient != getM(*it)) {
        return emitError() << "expected coefficients of Affine expressions "
                              "to be equal to m of n_out_of_m level";
      }
    }
  }
  // Before we can check that the level-rank is consistent/coherent
  // across all fields, we need to define it. The source-of-truth for
  // the `getLvlRank` method is the length of the level-types array,
  // since it must always be provided and have full rank; therefore we
  // use that same source-of-truth here.
  const Level lvlRank = lvlTypes.size();
  if (lvlRank == 0)
    return emitError() << "expected a non-empty array for lvlTypes";
  // We save `dimRank` here because we'll also need it to verify `dimSlices`.
  const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
  if (dimToLvl) {
    if (dimToLvl.getNumResults() != lvlRank)
      return emitError()
             << "level-rank mismatch between dimToLvl and lvlTypes: "
             << dimToLvl.getNumResults() << " != " << lvlRank;
    auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
    // Symbols can't be inferred but are acceptable.
    if (!inferRes && dimToLvl.getNumSymbols() == 0)
      return emitError() << "failed to infer lvlToDim from dimToLvl";
    if (lvlToDim && (inferRes != lvlToDim))
      return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
    if (dimRank > lvlRank)
      return emitError() << "unexpected dimToLvl mapping from " << dimRank
                         << " to " << lvlRank;
  }
  if (!dimSlices.empty()) {
    if (dimSlices.size() != dimRank)
      return emitError()
             << "dimension-rank mismatch between dimSlices and dimToLvl: "
             << dimSlices.size() << " != " << dimRank;
    // Compiler support for `dimSlices` currently requires that the two
    // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
    if (dimRank != lvlRank)
      return emitError()
             << "dimSlices expected dimension-rank to match level-rank: "
             << dimRank << " != " << lvlRank;
  }
  return success();
}

LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<Size> dimShape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity. In particular, this ensures that the
  // level-rank is coherent across all the fields.
  if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
                    getPosWidth(), getCrdWidth(), getExplicitVal(),
                    getImplicitVal(), getDimSlices())))
    return failure();
  // Check integrity with tensor type specifics. In particular, we
  // need only check that the dimension-rank of the tensor agrees with
  // the dimension-rank of the encoding.
  const Dimension dimRank = dimShape.size();
  if (dimRank == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimRank() != dimRank)
    return emitError()
           << "dimension-rank mismatch between encoding and tensor shape: "
           << getDimRank() << " != " << dimRank;
  if (auto expVal = getExplicitVal()) {
    Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
    if (attrType != elementType) {
      return emitError() << "explicit value type mismatch between encoding and "
                         << "tensor element type: " << attrType
                         << " != " << elementType;
    }
  }
  if (auto impVal = getImplicitVal()) {
    Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
    if (attrType != elementType) {
      return emitError() << "implicit value type mismatch between encoding and "
                         << "tensor element type: " << attrType
                         << " != " << elementType;
    }
    // Currently, we only support zero as the implicit value.
    auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
    auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
    auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
    if ((impFVal && impFVal.getValue().isNonZero()) ||
        (impIntVal && !impIntVal.getValue().isZero()) ||
        (impComplexVal && (impComplexVal.getImag().isNonZero() ||
                           impComplexVal.getReal().isNonZero()))) {
      return emitError() << "implicit value must be zero";
    }
  }
  return success();
}

Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
  SmallVector<COOSegment> coo = getCOOSegments();
  assert(coo.size() == 1 || coo.empty());
  if (!coo.empty() && coo.front().isAoS()) {
    return coo.front().lvlRange.first;
  }
  return getLvlRank();
}
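
// For example, for (dense, compressed(nonunique), singleton) the AoS COO
// segment starts at level 1, while a purely dense or CSR-like encoding has no
// COO segment, so the level rank is returned instead.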

SmallVector<COOSegment>
mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
  SmallVector<COOSegment> ret;
  if (getLvlRank() <= 1)
    return ret;

  ArrayRef<LevelType> lts = getLvlTypes();
  Level l = 0;
  while (l < getLvlRank()) {
    auto lt = lts[l];
    if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
      auto cur = lts.begin() + l;
      auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      });
      unsigned cooLen = std::distance(cur, end);
      if (cooLen > 1) {
        // To support mixed SoA/AoS COO, we should break the segment when the
        // storage scheme changes; for now we faithfully assume that all
        // consecutive singleton levels have the same storage format, as
        // verified by the STEA verifier.
        ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
                                 lts[l + 1].isa<LevelPropNonDefault::SoA>()});
      }
      l += cooLen;
    } else {
      l++;
    }
  }
  return ret;
}
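
// For example, (compressed(nonunique), singleton(soa)) produces one segment
// covering levels [0, 2) with isSoA = true, whereas the AoS variant
// (compressed(nonunique), singleton) produces the same range with
// isSoA = false.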

//===----------------------------------------------------------------------===//
// SparseTensorType Methods.
//===----------------------------------------------------------------------===//

bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
                                                      bool isUnique) const {
  if (!hasEncoding())
    return false;
  if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
    return false;
  for (Level l = startLvl + 1; l < lvlRank; ++l)
    if (!isSingletonLvl(l))
      return false;
  // If isUnique is true, then make sure that the last level is unique,
  // that is, when lvlRank == 1, the only compressed level is unique,
  // and when lvlRank > 1, the last singleton is unique.
  return !isUnique || isUniqueLvl(lvlRank - 1);
}

RankedTensorType
mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
  SmallVector<LevelType> lvlTypes;
  lvlTypes.reserve(lvlRank);
  // A non-unique compressed level at the beginning (unless this is
  // also the last level, in which case it is unique).
  lvlTypes.push_back(
      *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
  if (lvlRank > 1) {
    // Followed by n-2 non-unique singleton levels.
    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
                *buildLevelType(LevelFormat::Singleton, ordered, false));
    // Ends with a unique singleton level.
    lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
  }
  auto enc = SparseTensorEncodingAttr::get(
      getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal());
  return RankedTensorType::get(getDimShape(), getElementType(), enc);
}

//===----------------------------------------------------------------------===//
// Convenience Methods.
//===----------------------------------------------------------------------===//

SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
  if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
    return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
  if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
    return mdtp.getEncoding();
  return nullptr;
}

AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
                                             MLIRContext *context) {
  auto map = static_cast<AffineMap>(dimToLvl);
  AffineMap lvlToDim;
  // Return an empty lvlToDim when inference is not successful.
  if (!map || map.getNumSymbols() != 0) {
    lvlToDim = AffineMap();
  } else if (map.isPermutation()) {
    lvlToDim = inversePermutation(map);
  } else if (isBlockSparsity(map)) {
    lvlToDim = inverseBlockSparsity(map, context);
  }
  return lvlToDim;
}

AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
                                                    MLIRContext *context) {
  SmallVector<AffineExpr> lvlExprs;
  auto numLvls = dimToLvl.getNumResults();
  lvlExprs.reserve(numLvls);
  // lvlExprComponents stores information of the floordiv and mod operations
  // applied to the same dimension, so as to build the lvlToDim map.
  std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
  for (unsigned i = 0, n = numLvls; i < n; i++) {
    auto result = dimToLvl.getResult(i);
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::FloorDiv) {
        // Position of the dimension in dimToLvl.
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
               "expected only one floordiv for each dimension");
        SmallVector<AffineExpr, 3> components;
        // Level variable for floordiv.
        components.push_back(getAffineDimExpr(i, context));
        // Multiplier.
        components.push_back(binOp.getRHS());
        // The map key is the position of the dimension.
        lvlExprComponents[pos] = components;
      } else if (result.getKind() == AffineExprKind::Mod) {
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
               "expected floordiv before mod");
        // Add the level variable for mod to the same vector
        // as the corresponding floordiv.
        lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
      } else {
        assert(false && "expected floordiv or mod");
      }
    } else {
      lvlExprs.push_back(getAffineDimExpr(i, context));
    }
  }
  // Build lvlExprs from lvlExprComponents.
  // For example, for il = i floordiv 2 and ii = i mod 2, the components
  // would be [il, 2, ii]. They can be used to build the AffineExpr
  // i = il * 2 + ii in lvlToDim.
  for (auto &components : lvlExprComponents) {
    assert(components.second.size() == 3 &&
           "expected 3 components to build lvlExprs");
    auto mulOp = getAffineBinaryOpExpr(
        AffineExprKind::Mul, components.second[0], components.second[1]);
    auto addOp =
        getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
    lvlExprs.push_back(addOp);
  }
  return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
}
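
// For example, the 2x2 BSR map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
// is inverted to
//   (d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 2 + d3),
// where the floordiv levels supply the block coordinates and the mod levels
// supply the in-block offsets.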

SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
  assert(isBlockSparsity(dimToLvl) &&
         "expected dimToLvl to be block sparsity for calling getBlockSize");
  SmallVector<unsigned> blockSize;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::Mod) {
        blockSize.push_back(
            dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
      }
    } else {
      blockSize.push_back(0);
    }
  }
  return blockSize;
}
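
// For example, for (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2,
// d1 mod 2) this returns [2, 2], and for (d0, d1) -> (d0, d1 floordiv 3,
// d1 mod 3) it returns [0, 3], with 0 marking the unblocked dimension.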

bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
  if (!dimToLvl)
    return false;
  std::map<unsigned, int64_t> coefficientMap;
  bool hasBlock = false;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      // Check for "dim op const".
      auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
      auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
      if (!dimOp || !conOp || conOp.getValue() <= 0)
        return false;
      // Inspect "dim / const" or "dim % const".
      auto pos = dimOp.getPosition();
      if (binOp.getKind() == AffineExprKind::FloorDiv) {
        // Expect only one floordiv for each dimension.
        auto [it, inserted] = coefficientMap.try_emplace(pos);
        if (!inserted)
          return false;
        // Record the coefficient of the floordiv.
        it->second = conOp.getValue();
      } else if (binOp.getKind() == AffineExprKind::Mod) {
        // Expect a floordiv before the mod.
        auto it = coefficientMap.find(pos);
        if (it == coefficientMap.end())
          return false;
        // Expect the mod to have the same coefficient as the floordiv.
        if (conOp.getValue() != it->second)
          return false;
        hasBlock = true;
      } else {
        return false;
      }
    } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
      auto pos = dimOp.getPosition();
      // Expect the dim to be unset.
      if (!coefficientMap.try_emplace(pos, 0).second)
        return false;
    } else {
      return false;
    }
  }
  return hasBlock;
}
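
// For example, (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
// is accepted, while a plain permutation such as (d0, d1) -> (d1, d0) is
// rejected (no floordiv/mod pair), as is a map whose mod coefficient differs
// from its floordiv coefficient.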

bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
  auto hasNonIdentityMap = [](Value v) {
    auto stt = tryGetSparseTensorType(v);
    return stt && !stt->isIdentity();
  };

  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
         llvm::any_of(op->getResults(), hasNonIdentityMap);
}

Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
  if (enc) {
    assert(enc.isPermutation() && "Non permutation map not supported");
    if (const auto dimToLvl = enc.getDimToLvl())
      return dimToLvl.getDimPosition(l);
  }
  return l;
}

Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
  if (enc) {
    assert(enc.isPermutation() && "Non permutation map not supported");
    if (const auto lvlToDim = enc.getLvlToDim())
      return lvlToDim.getDimPosition(d);
  }
  return d;
}

/// We normalize the sparse tensor encoding attribute by always using
/// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
/// as other variants) lead to the same storage specifier type, and by
/// stripping irrelevant fields that do not alter the sparse tensor memory
/// layout.
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
  SmallVector<LevelType> lts;
  for (auto lt : enc.getLvlTypes())
    lts.push_back(lt.stripStorageIrrelevantProperties());

  return SparseTensorEncodingAttr::get(
      enc.getContext(), lts,
      AffineMap(), // dimToLvl (irrelevant to storage specifier)
      AffineMap(), // lvlToDim (irrelevant to storage specifier)
      // Always use `index` for memSize and lvlSize instead of reusing
      // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
      // value for different bitwidths; it also avoids casting between index
      // and integer (returned by DimOp).
      0, 0,
      Attribute(), // explicitVal (irrelevant to storage specifier)
      Attribute(), // implicitVal (irrelevant to storage specifier)
      enc.getDimSlices());
}
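
// For example, both of the 2-d encodings
//   (d0, d1) -> (d0 : dense, d1 : compressed(nonunique, nonordered))
//   (d0, d1) -> (d0 : dense, d1 : compressed(nonunique))
// normalize to (d0 : dense, d1 : compressed) with the dimToLvl map,
// bitwidths, and explicit/implicit values cleared, and thus share a single
// storage specifier type.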

StorageSpecifierType
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
  return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
}

StorageSpecifierType
StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                                 MLIRContext *ctx,
                                 SparseTensorEncodingAttr encoding) {
  return Base::getChecked(emitError, ctx,
                          getNormalizedEncodingForSpecifier(encoding));
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
//===----------------------------------------------------------------------===//

static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
  return success(lvl < getSparseTensorType(tensor).getLvlRank());
}

static LogicalResult isMatchingWidth(Value mem, unsigned width) {
  const Type etp = getMemRefType(mem).getElementType();
  return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
}

static LogicalResult verifySparsifierGetterSetter(
    StorageSpecifierKind mdKind, std::optional<Level> lvl,
    TypedValue<StorageSpecifierType> md, Operation *op) {
  if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
    return op->emitError(
        "redundant level argument for querying value memory size");
  }

  const auto enc = md.getType().getEncoding();
  const Level lvlRank = enc.getLvlRank();

  if (mdKind == StorageSpecifierKind::DimOffset ||
      mdKind == StorageSpecifierKind::DimStride)
    if (!enc.isSlice())
      return op->emitError("requested slice data on non-slice tensor");

  if (mdKind != StorageSpecifierKind::ValMemSize) {
    if (!lvl)
      return op->emitError("missing level argument");

    const Level l = lvl.value();
    if (l >= lvlRank)
      return op->emitError("requested level is out of bounds");

    if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
      return op->emitError(
          "requested position memory size on a singleton level");
  }
  return success();
}

static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
  switch (kind) {
  case SparseTensorFieldKind::CrdMemRef:
    return stt.getCrdType();
  case SparseTensorFieldKind::PosMemRef:
    return stt.getPosType();
  case SparseTensorFieldKind::ValMemRef:
    return stt.getElementType();
  case SparseTensorFieldKind::StorageSpec:
    return nullptr;
  }
  llvm_unreachable("Unrecognizable FieldKind");
}

static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                      SparseTensorType stt,
                                      RankedTensorType valTp,
                                      TypeRange lvlTps) {
  if (requiresStaticShape && !stt.hasStaticDimShape())
    return op->emitError("the sparse-tensor must have static shape");
  if (!stt.hasEncoding())
    return op->emitError("the sparse-tensor must have an encoding attribute");

  // Verify the trailing COO.
  Level cooStartLvl = stt.getAoSCOOStart();
  if (cooStartLvl < stt.getLvlRank()) {
    // We only support trailing COO for now; it must be the last input.
    auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
    // The coordinates should be in the shape of <? x rank>.
    unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
    if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
      return op->emitError("input/output trailing COO level-ranks don't match");
    }
  }

  // Verify that all types match.
  StorageLayout layout(stt.getEncoding());
  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
    return op->emitError("inconsistent number of fields between input/output");

  unsigned idx = 0;
  bool misMatch = false;
  layout.foreachField([&idx, &misMatch, stt, valTp,
                       lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
                               Level lvl, LevelType lt) -> bool {
    if (fKind == SparseTensorFieldKind::StorageSpec)
      return true;

    Type inputTp = nullptr;
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      inputTp = valTp;
    } else {
      assert(fid == idx && stt.getLvlType(lvl) == lt);
      inputTp = lvlTps[idx++];
    }
    // The input element type and expected element type should match.
    Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
    Type expElemTp = getFieldElemType(stt, fKind);
    if (inpElemTp != expElemTp) {
      misMatch = true;
      return false; // to terminate the iteration
    }
    return true;
  });

  if (misMatch)
    return op->emitError("input/output element-types don't match");
  return success();
}

LogicalResult AssembleOp::verify() {
  RankedTensorType valuesTp = getValues().getType();
  const auto lvlsTp = getLevels().getTypes();
  const auto resTp = getSparseTensorType(getResult());
  return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
}
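
// Schematically (a sketch, assuming a #CSR encoding with index-typed level
// buffers), the verified op looks like:
//
//   %st = sparse_tensor.assemble (%pos, %crd), %vals
//       : (tensor<3xindex>, tensor<4xindex>), tensor<4xf64>
//       to tensor<2x4xf64, #CSR>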

LogicalResult DisassembleOp::verify() {
  if (getOutValues().getType() != getRetValues().getType())
    return emitError("output values and return value type mismatch");

  for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
    if (ot.getType() != rt.getType())
      return emitError("output levels and return levels type mismatch");

  RankedTensorType valuesTp = getRetValues().getType();
  const auto lvlsTp = getRetLevels().getTypes();
  const auto srcTp = getSparseTensorType(getTensor());
  return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
}

LogicalResult ConvertOp::verify() {
  RankedTensorType tp1 = getSource().getType();
  RankedTensorType tp2 = getDest().getType();
  if (tp1.getRank() != tp2.getRank())
    return emitError("unexpected conversion mismatch in rank");
  auto dstEnc =
      llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
  if (dstEnc && dstEnc.isSlice())
    return emitError("cannot convert to a sparse tensor slice");

  auto shape1 = tp1.getShape();
  auto shape2 = tp2.getShape();
  // Accept size matches between the source and the destination type
  // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
  // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
  for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
    if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
      return emitError("unexpected conversion mismatch in dimension ") << d;
  return success();
}

OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
  if (getType() == getSource().getType())
    return getSource();
  return {};
}

bool ConvertOp::needsExtraSort() {
  SparseTensorType srcStt = getSparseTensorType(getSource());
  SparseTensorType dstStt = getSparseTensorType(getDest());

  // We do not need an extra sort when returning unordered sparse tensors or
  // dense tensors, since dense tensors support random access.
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
      srcStt.hasSameDimToLvl(dstStt)) {
    return false;
  }

  // The source and dest tensors are ordered in different ways. We only do
  // direct dense to sparse conversion when the dense input is defined by a
  // sparse constant. Note that we can theoretically always directly convert
  // from dense inputs by rotating dense loops, but it leads to bad cache
  // locality and hurts performance.
  if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
    if (isa<SparseElementsAttr>(constOp.getValue()))
      return false;

  return true;
}

LogicalResult CrdTranslateOp::verify() {
  uint64_t inRank = getEncoder().getLvlRank();
  uint64_t outRank = getEncoder().getDimRank();

  if (getDirection() == CrdTransDirectionKind::dim2lvl)
    std::swap(inRank, outRank);

  if (inRank != getInCrds().size() || outRank != getOutCrds().size())
    return emitError("Coordinate rank mismatch with encoding");

  return success();
}

LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                   SmallVectorImpl<OpFoldResult> &results) {
  if (getEncoder().isIdentity()) {
    results.assign(getInCrds().begin(), getInCrds().end());
    return success();
  }
  if (getEncoder().isPermutation()) {
    AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
                         ? getEncoder().getDimToLvl()
                         : getEncoder().getLvlToDim();
    for (AffineExpr exp : perm.getResults())
      results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
    return success();
  }

  // Fuse dim2lvl/lvl2dim pairs.
  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
  bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
    return v.getDefiningOp() == def;
  });
  if (!sameDef)
    return failure();

  bool oppositeDir = def.getDirection() != getDirection();
  bool sameOracle =
      def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
  bool sameCount = def.getNumResults() == getInCrds().size();
  if (!oppositeDir || !sameOracle || !sameCount)
    return failure();

  // The definition produces the coordinates in the same order as the input
  // coordinates.
  bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
                                [](auto valuePair) {
                                  auto [lhs, rhs] = valuePair;
                                  return lhs == rhs;
                                });

  if (!sameOrder)
    return failure();
  // l1 = dim2lvl (lvl2dim l0)
  // ==> l0
  results.append(def.getInCrds().begin(), def.getInCrds().end());
  return success();
}

void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
                  int64_t index) {
  Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
  return build(builder, state, source, val);
}

LogicalResult LvlOp::verify() {
  if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
    auto stt = getSparseTensorType(getSource());
    if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
      return emitError(
          "Level index exceeds the rank of the input sparse tensor");
  }
  return success();
}

std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability LvlOp::getSpeculatability() {
  auto constantIndex = getConstantLvlIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  assert(constantIndex <
         cast<RankedTensorType>(getSource().getType()).getRank());
  return Speculation::Speculatable;
}

OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
  auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!lvlIndex)
    return {};

  Level lvl = lvlIndex.getAPSInt().getZExtValue();
  auto stt = getSparseTensorType(getSource());
  if (lvl >= stt.getLvlRank()) {
    // Follows the same convention used by the tensor.dim operation. Out of
    // bound indices produce undefined behavior but are still valid IR. Don't
    // choke on them.
    return {};
  }

  // Helper lambda to build an IndexAttr.
  auto getIndexAttr = [this](int64_t lvlSz) {
    return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
  };

  SmallVector<Size> lvlShape = stt.getLvlShape();
  if (!ShapedType::isDynamic(lvlShape[lvl]))
    return getIndexAttr(lvlShape[lvl]);

  return {};
}

void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                             SparseTensorEncodingAttr dstEnc, Value source) {
  auto srcStt = getSparseTensorType(source);
  SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
  SmallVector<int64_t> dstDimShape =
      dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
  auto dstTp =
      RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
  return build(odsBuilder, odsState, dstTp, source);
}

LogicalResult ReinterpretMapOp::verify() {
  auto srcStt = getSparseTensorType(getSource());
  auto dstStt = getSparseTensorType(getDest());
  ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
  ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();

  if (srcLvlTps.size() != dstLvlTps.size())
    return emitError("Level rank mismatch between source/dest tensors");

  for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
    if (srcLvlTp != dstLvlTp)
      return emitError("Level type mismatch between source/dest tensors");

  if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
      srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
    return emitError("Crd/Pos width mismatch between source/dest tensors");
  }

  if (srcStt.getElementType() != dstStt.getElementType())
    return emitError("Element type mismatch between source/dest tensors");

  SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
  SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
  for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
    if (srcLvlSz != dstLvlSz) {
      // Should we allow one side to be dynamically sized, e.g., should <?x?>
      // be compatible with <3x4>? For now, we require all the level sizes to
      // be *exactly* matched for simplicity.
      return emitError("Level size mismatch between source/dest tensors");
    }
  }

  return success();
}
1587 
1588 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1589  if (getSource().getType() == getDest().getType())
1590  return getSource();
1591 
1592  if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1593  // A -> B, B -> A ==> A
1594  if (def.getSource().getType() == getDest().getType())
1595  return def.getSource();
1596  }
1597  return {};
1598 }
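// Illustrative round-trip fold (syntax sketched; the encodings #A and #B are
// assumed examples):
//   %t1 = sparse_tensor.reinterpret_map %t0 : tensor<..., #A> to tensor<..., #B>
//   %t2 = sparse_tensor.reinterpret_map %t1 : tensor<..., #B> to tensor<..., #A>
// folds %t2 directly to %t0.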
1599 
1600 template <typename ToBufferOp>
1601 static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
1602  OpaqueProperties prop,
1603  RegionRange region,
1604  SmallVectorImpl<mlir::Type> &ret) {
1605  typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1606  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
1607  Type elemTp = nullptr;
1608  bool withStride = false;
1609  if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1610  elemTp = stt.getPosType();
1611  } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1612  std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1613  elemTp = stt.getCrdType();
1614  if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1615  withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
1616  } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1617  elemTp = stt.getElementType();
1618  }
1619 
1620  assert(elemTp && "unhandled operation.");
1621  SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
1622  bufShape.push_back(ShapedType::kDynamic);
1623 
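  // In an AoS COO region, the coordinates of all levels are stored
  // interleaved in a single buffer, so a view of one level's coordinates is
  // non-contiguous and hence needs a dynamic strided layout.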
1624  auto layout = withStride ? StridedLayoutAttr::get(
1625  stt.getContext(), ShapedType::kDynamic,
1626  {ShapedType::kDynamic})
1627  : StridedLayoutAttr();
1628  ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1629  return success();
1630 }
1631 
1632 LogicalResult ToPositionsOp::verify() {
1633  auto stt = getSparseTensorType(getTensor());
1634  if (failed(lvlIsInBounds(getLevel(), getTensor())))
1635  return emitError("requested level is out of bounds");
1636  if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
1637  return emitError("unexpected type for positions");
1638  return success();
1639 }
1640 
1641 LogicalResult
1642 ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1643  ValueRange ops, DictionaryAttr attr,
1644  OpaqueProperties prop, RegionRange region,
1645  SmallVectorImpl<mlir::Type> &ret) {
1646  return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
1647 }
1648 
1649 LogicalResult ToCoordinatesOp::verify() {
1650  auto stt = getSparseTensorType(getTensor());
1651  if (failed(lvlIsInBounds(getLevel(), getTensor())))
1652  return emitError("requested level is out of bounds");
1653  if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
1654  return emitError("unexpected type for coordinates");
1655  return success();
1656 }
1657 
1658 LogicalResult
1659 ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1660  ValueRange ops, DictionaryAttr attr,
1661  OpaqueProperties prop, RegionRange region,
1662  SmallVectorImpl<mlir::Type> &ret) {
1663  return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
1664 }
1665 
1666 LogicalResult ToCoordinatesBufferOp::verify() {
1667  auto stt = getSparseTensorType(getTensor());
1668  if (stt.getAoSCOOStart() >= stt.getLvlRank())
1669  return emitError("expected sparse tensor with a COO region");
1670  return success();
1671 }
1672 
1673 LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1674  MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
1675  DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
1676  SmallVectorImpl<mlir::Type> &ret) {
1677  return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
1678  ret);
1679 }
1680 
1681 LogicalResult ToValuesOp::verify() {
1682  auto stt = getSparseTensorType(getTensor());
1683  auto mtp = getMemRefType(getResult());
1684  if (stt.getElementType() != mtp.getElementType())
1685  return emitError("unexpected mismatch in element types");
1686  return success();
1687 }
1688 
1689 LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1690  std::optional<Location> loc,
1691  ValueRange ops, DictionaryAttr attr,
1692  OpaqueProperties prop,
1693  RegionRange region,
1694  SmallVectorImpl<mlir::Type> &ret) {
1695  return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
1696 }
1697 
1698 LogicalResult ToSliceOffsetOp::verify() {
1699  auto rank = getSlice().getType().getRank();
1700  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1701  return emitError("requested dimension out of bound");
1702  return success();
1703 }
1704 
1705 LogicalResult ToSliceStrideOp::verify() {
1706  auto rank = getSlice().getType().getRank();
1707  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1708  return emitError("requested dimension out of bound");
1709  return success();
1710 }
1711 
1712 LogicalResult GetStorageSpecifierOp::verify() {
1713  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1714  getSpecifier(), getOperation());
1715 }
1716 
1717 template <typename SpecifierOp>
1718 static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
1719  return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1720 }
1721 
1722 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1723  const StorageSpecifierKind kind = getSpecifierKind();
1724  const auto lvl = getLevel();
1725  for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1726  if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1727  return op.getValue();
1728  return {};
1729 }
1730 
1731 LogicalResult SetStorageSpecifierOp::verify() {
1732  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1733  getSpecifier(), getOperation());
1734 }
1735 
1736 template <class T>
1737 static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1738  const char *regionName,
1739  TypeRange inputTypes, Type outputType) {
1740  unsigned numArgs = region.getNumArguments();
1741  unsigned expectedNum = inputTypes.size();
1742  if (numArgs != expectedNum)
1743  return op->emitError() << regionName << " region must have exactly "
1744  << expectedNum << " arguments";
1745 
1746  for (unsigned i = 0; i < numArgs; i++) {
1747  Type typ = region.getArgument(i).getType();
1748  if (typ != inputTypes[i])
1749  return op->emitError() << regionName << " region argument " << (i + 1)
1750  << " type mismatch";
1751  }
1752  Operation *term = region.front().getTerminator();
1753  YieldOp yield = dyn_cast<YieldOp>(term);
1754  if (!yield)
1755  return op->emitError() << regionName
1756  << " region must end with sparse_tensor.yield";
1757  if (!yield.hasSingleResult() ||
1758  yield.getSingleResult().getType() != outputType)
1759  return op->emitError() << regionName << " region yield type mismatch";
1760 
1761  return success();
1762 }
1763 
1764 LogicalResult BinaryOp::verify() {
1765  NamedAttrList attrs = (*this)->getAttrs();
1766  Type leftType = getX().getType();
1767  Type rightType = getY().getType();
1768  Type outputType = getOutput().getType();
1769  Region &overlap = getOverlapRegion();
1770  Region &left = getLeftRegion();
1771  Region &right = getRightRegion();
1772 
1773  // Check correct number of block arguments and return type for each
1774  // non-empty region.
1775  if (!overlap.empty()) {
1776  if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1777  TypeRange{leftType, rightType}, outputType)))
1778  return failure();
1779  }
1780  if (!left.empty()) {
1781  if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1782  outputType)))
1783  return failure();
1784  } else if (getLeftIdentity()) {
1785  if (leftType != outputType)
1786  return emitError("left=identity requires first argument to have the same "
1787  "type as the output");
1788  }
1789  if (!right.empty()) {
1790  if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1791  outputType)))
1792  return failure();
1793  } else if (getRightIdentity()) {
1794  if (rightType != outputType)
1795  return emitError("right=identity requires second argument to have the "
1796  "same type as the output");
1797  }
1798  return success();
1799 }
1800 
1801 LogicalResult UnaryOp::verify() {
1802  Type inputType = getX().getType();
1803  Type outputType = getOutput().getType();
1804 
1805  // Check correct number of block arguments and return type for each
1806  // non-empty region.
1807  Region &present = getPresentRegion();
1808  if (!present.empty()) {
1809  if (failed(verifyNumBlockArgs(this, present, "present",
1810  TypeRange{inputType}, outputType)))
1811  return failure();
1812  }
1813  Region &absent = getAbsentRegion();
1814  if (!absent.empty()) {
1815  if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
1816  outputType)))
1817  return failure();
1818  // Absent branch can only yield invariant values.
1819  Block *absentBlock = &absent.front();
1820  Block *parent = getOperation()->getBlock();
1821  Value absentVal =
1822  cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
1823  if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
1824  if (arg.getOwner() == parent)
1825  return emitError("absent region cannot yield linalg argument");
1826  } else if (Operation *def = absentVal.getDefiningOp()) {
1827  if (!isa<arith::ConstantOp>(def) &&
1828  (def->getBlock() == absentBlock || def->getBlock() == parent))
1829  return emitError("absent region cannot yield locally computed value");
1830  }
1831  }
1832  return success();
1833 }
1834 
1835 bool ConcatenateOp::needsExtraSort() {
1836  SparseTensorType dstStt = getSparseTensorType(*this);
1837  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1838  return false;
1839 
1840  bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1841  return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1842  });
1843  // TODO: When conDim != 0, as long as conDim corresponds to the first level
1844  // in all input/output buffers, and all input/output buffers have the same
1845  // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenating
1846  // CSC matrices along the column dimension).
1847  bool directLowerable =
1848  allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1849  return !directLowerable;
1850 }
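// Illustrative consequence: concatenating ordered CSR matrices along
// dimension 0 (rows) keeps coordinates ordered and lowers directly, whereas
// concatenating the same matrices along dimension 1 (columns) interleaves
// coordinates and requires the temporary COO buffer plus an extra sort.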
1851 
1852 LogicalResult ConcatenateOp::verify() {
1853  const auto dstTp = getSparseTensorType(*this);
1854  const Dimension concatDim = getDimension();
1855  const Dimension dimRank = dstTp.getDimRank();
1856 
1857  if (getInputs().size() <= 1)
1858  return emitError("Need at least two tensors to concatenate.");
1859 
1860  if (concatDim >= dimRank)
1861  return emitError(llvm::formatv(
1862  "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1863  concatDim, dimRank));
1864 
1865  for (const auto &it : llvm::enumerate(getInputs())) {
1866  const auto i = it.index();
1867  const auto srcTp = getSparseTensorType(it.value());
1868  if (srcTp.hasDynamicDimShape())
1869  return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1870  const Dimension srcDimRank = srcTp.getDimRank();
1871  if (srcDimRank != dimRank)
1872  return emitError(
1873  llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1874  "from the output tensor (rank={2}).",
1875  i, srcDimRank, dimRank));
1876  }
1877 
1878  for (Dimension d = 0; d < dimRank; d++) {
1879  const Size dstSh = dstTp.getDimShape()[d];
1880  if (d == concatDim) {
1881  if (!ShapedType::isDynamic(dstSh)) {
1882  // If we reach here, then all inputs have static shapes. So we
1883  // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1884  // to avoid redundant assertions in the loop.
1885  Size sumSz = 0;
1886  for (const auto src : getInputs())
1887  sumSz += getSparseTensorType(src).getDimShape()[d];
1888  // If all dimensions are statically known, the sum of all the input
1889  // dimensions should be equal to the output dimension.
1890  if (sumSz != dstSh)
1891  return emitError(
1892  "The concatenation dimension of the output tensor should be the "
1893  "sum of all the concatenation dimensions of the input tensors.");
1894  }
1895  } else {
1896  Size prev = dstSh;
1897  for (const auto src : getInputs()) {
1898  const auto sh = getSparseTensorType(src).getDimShape()[d];
1899  if (!ShapedType::isDynamic(prev) && sh != prev)
1900  return emitError("All dimensions (expect for the concatenating one) "
1901  "should be equal.");
1902  prev = sh;
1903  }
1904  }
1905  }
1906 
1907  return success();
1908 }
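// Illustrative shape check (encodings assumed): concatenating
// tensor<2x4xf64, #CSR> and tensor<3x4xf64, #CSR> along dimension 0 must
// yield tensor<5x4xf64, #CSR>, since 2 + 3 = 5; a mismatched static output
// size, or a mismatch in any non-concatenated dimension, is rejected above.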
1909 
1910 void PushBackOp::build(OpBuilder &builder, OperationState &result,
1911  Value curSize, Value inBuffer, Value value) {
1912  build(builder, result, curSize, inBuffer, value, Value());
1913 }
1914 
1915 LogicalResult PushBackOp::verify() {
1916  if (Value n = getN()) {
1917  std::optional<int64_t> nValue = getConstantIntValue(n);
1918  if (nValue && nValue.value() < 1)
1919  return emitOpError("n must be not less than 1");
1920  }
1921  return success();
1922 }
1923 
1924 LogicalResult CompressOp::verify() {
1925  const auto stt = getSparseTensorType(getTensor());
1926  if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1927  return emitOpError("incorrect number of coordinates");
1928  return success();
1929 }
1930 
1931 void ForeachOp::build(
1932  OpBuilder &builder, OperationState &result, Value tensor,
1933  ValueRange initArgs, AffineMapAttr order,
1934  function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
1935  bodyBuilder) {
1936  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
1937  // Builds foreach body.
1938  if (!bodyBuilder)
1939  return;
1940  const auto stt = getSparseTensorType(tensor);
1941  const Dimension dimRank = stt.getDimRank();
1942 
1943  // Starts with `dimRank`-many coordinates.
1944  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
1945  // Followed by one value.
1946  blockArgTypes.push_back(stt.getElementType());
1947  // Followed by the reduction variables.
1948  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
1949 
1950  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());
1951 
1952  OpBuilder::InsertionGuard guard(builder);
1953  auto &region = *result.regions.front();
1954  Block *bodyBlock =
1955  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1956  bodyBuilder(builder, result.location,
1957  bodyBlock->getArguments().slice(0, dimRank),
1958  bodyBlock->getArguments()[dimRank],
1959  bodyBlock->getArguments().drop_front(dimRank + 1));
1960 }
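// Illustrative block signature produced by this builder (types assumed):
// for a tensor<?x?xf64, #CSR> with a single index-typed init argument, the
// body block receives (%d0: index, %d1: index, %value: f64, %acc: index).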
1961 
1962 LogicalResult ForeachOp::verify() {
1963  const auto t = getSparseTensorType(getTensor());
1964  const Dimension dimRank = t.getDimRank();
1965  const auto args = getBody()->getArguments();
1966 
1967  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1968  return emitError("Level traverse order does not match tensor's level rank");
1969 
1970  if (dimRank + 1 + getInitArgs().size() != args.size())
1971  return emitError("Unmatched number of arguments in the block");
1972 
1973  if (getNumResults() != getInitArgs().size())
1974  return emitError("Mismatch in number of init arguments and results");
1975 
1976  if (getResultTypes() != getInitArgs().getTypes())
1977  return emitError("Mismatch in types of init arguments and results");
1978 
1979  // Cannot mark this const, because the getters aren't.
1980  auto yield = cast<YieldOp>(getBody()->getTerminator());
1981  if (yield.getNumOperands() != getNumResults() ||
1982  yield.getOperands().getTypes() != getResultTypes())
1983  return emitError("Mismatch in types of yield values and results");
1984 
1985  const auto iTp = IndexType::get(getContext());
1986  for (Dimension d = 0; d < dimRank; d++)
1987  if (args[d].getType() != iTp)
1988  return emitError(
1989  llvm::formatv("Expecting Index type for argument at index {0}", d));
1990 
1991  const auto elemTp = t.getElementType();
1992  const auto valueTp = args[dimRank].getType();
1993  if (elemTp != valueTp)
1994  return emitError(
1995  llvm::formatv("Unmatched element type between input tensor and "
1996  "block argument, expected:{0}, got: {1}",
1997  elemTp, valueTp));
1998  return success();
1999 }
2000 
2001 OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
2002  if (getSparseTensorEncoding(getInputCoo().getType()) ==
2003  getSparseTensorEncoding(getResultCoo().getType()))
2004  return getInputCoo();
2005 
2006  return {};
2007 }
2008 
2009 LogicalResult ReorderCOOOp::verify() {
2010  SparseTensorType srcStt = getSparseTensorType(getInputCoo());
2011  SparseTensorType dstStt = getSparseTensorType(getResultCoo());
2012 
2013  if (!srcStt.isCOOType() || !dstStt.isCOOType())
2014  return emitError("Expected COO sparse tensors only");
2015 
2016  if (!srcStt.hasSameDimToLvl(dstStt))
2017  return emitError("Unmatched dim2lvl map between input and result COO");
2018 
2019  if (srcStt.getPosType() != dstStt.getPosType() ||
2020  srcStt.getCrdType() != dstStt.getCrdType() ||
2021  srcStt.getElementType() != dstStt.getElementType())
2022  return emitError("Unmatched storage format between input and result COO");
2023 
2024  return success();
2025 }
2026 
2027 LogicalResult ReduceOp::verify() {
2028  Type inputType = getX().getType();
2029  Region &formula = getRegion();
2030  return verifyNumBlockArgs(this, formula, "reduce",
2031  TypeRange{inputType, inputType}, inputType);
2032 }
2033 
2034 LogicalResult SelectOp::verify() {
2035  Builder b(getContext());
2036  Type inputType = getX().getType();
2037  Type boolType = b.getI1Type();
2038  Region &formula = getRegion();
2039  return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
2040  boolType);
2041 }
2042 
2043 LogicalResult SortOp::verify() {
2044  AffineMap xPerm = getPermMap();
2045  uint64_t nx = xPerm.getNumDims();
2046  if (nx < 1)
2047  return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
2048 
2049  if (!xPerm.isPermutation())
2050  return emitError(
2051  llvm::formatv("Expected a permutation map, got {0}", xPerm));
2052 
2053  // We can't check the size of the buffers when n or buffer dimensions aren't
2054  // compile-time constants.
2055  std::optional<int64_t> cn = getConstantIntValue(getN());
2056  if (!cn)
2057  return success();
2058 
2059  // Verify dimensions.
2060  const auto checkDim = [&](Value v, Size minSize,
2061  const char *message) -> LogicalResult {
2062  const Size sh = getMemRefType(v).getShape()[0];
2063  if (!ShapedType::isDynamic(sh) && sh < minSize)
2064  return emitError(
2065  llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
2066  return success();
2067  };
2068  uint64_t n = cn.value();
2069  uint64_t ny = 0;
2070  if (auto nyAttr = getNyAttr())
2071  ny = nyAttr.getInt();
2072  if (failed(checkDim(getXy(), n * (nx + ny),
2073  "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
2074  return failure();
2075  for (Value opnd : getYs())
2076  if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
2077  return failure();
2078 
2079  return success();
2080 }
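// Illustrative bound check: for n = 3, rank(perm_map) = 2, and ny = 1, the
// xy buffer must hold at least n * (nx + ny) = 3 * 3 = 9 elements, and each
// y buffer at least n = 3 elements.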
2081 
2082 //===----------------------------------------------------------------------===//
2083 // Sparse Tensor Iteration Operations.
2084 //===----------------------------------------------------------------------===//
2085 
2086 IterSpaceType IteratorType::getIterSpaceType() const {
2087  return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
2088  getHiLvl());
2089 }
2090 
2091 IteratorType IterSpaceType::getIteratorType() const {
2092  return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
2093 }
2094 
2095 /// Parses a level range in the form "$lo `to` $hi"
2096 /// or simply "$lo" if $hi - $lo = 1
2097 static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
2098  Level &lvlHi) {
2099  if (parser.parseInteger(lvlLo))
2100  return failure();
2101 
2102  if (succeeded(parser.parseOptionalKeyword("to"))) {
2103  if (parser.parseInteger(lvlHi))
2104  return failure();
2105  } else {
2106  lvlHi = lvlLo + 1;
2107  }
2108 
2109  if (lvlHi <= lvlLo)
2110  return parser.emitError(parser.getNameLoc(),
2111  "expect larger level upper bound than lower bound");
2112 
2113  return success();
2114 }
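// Illustrative inputs: "2 to 4" yields lvlLo = 2, lvlHi = 4; a bare "2"
// yields lvlLo = 2, lvlHi = 3; "4 to 2" is rejected since the upper bound
// must be strictly larger than the lower bound.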
2115 
2116 /// Parses a level range in the form "$lo `to` $hi"
2117 /// or simply "$lo" if $hi - $lo = 1
2118 static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
2119  IntegerAttr &lvlHiAttr) {
2120  Level lvlLo, lvlHi;
2121  if (parseLevelRange(parser, lvlLo, lvlHi))
2122  return failure();
2123 
2124  lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
2125  lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
2126  return success();
2127 }
2128 
2129 /// Prints a level range in the form "$lo `to` $hi"
2130 /// or simply "$lo" if $hi - $lo = 1
2131 static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2132 
2133  if (lo + 1 == hi)
2134  p << lo;
2135  else
2136  p << lo << " to " << hi;
2137 }
2138 
2139 /// Prints a level range in the form "$lo `to` $hi"
2140 /// or simply "$lo" if $hi - $lo = 1
2141 static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
2142  IntegerAttr lvlHi) {
2143  unsigned lo = lvlLo.getValue().getZExtValue();
2144  unsigned hi = lvlHi.getValue().getZExtValue();
2145  printLevelRange(p, lo, hi);
2146 }
2147 
2148 /// Parses a list of optionally-defined values in the form of
2149 /// "(%val0, _, %val1, ...)", where `_` is used to annotate that the
2150 /// corresponding value is not defined (e.g., to represent an undefined
2151 /// coordinate in the sparse iteration space).
2152 static ParseResult parseOptionalDefinedList(
2153  OpAsmParser &parser, OperationState &state, I64BitSet &definedSet,
2154  SmallVectorImpl<OpAsmParser::Argument> &definedArgs,
2155  unsigned maxCnt = std::numeric_limits<unsigned>::max(),
2156  OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) {
2157  unsigned cnt = 0;
2158  ParseResult crdList =
2159  parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult {
2160  if (parser.parseOptionalKeyword("_")) {
2161  if (parser.parseArgument(definedArgs.emplace_back()))
2162  return failure();
2163  definedSet.set(cnt);
2164  }
2165  cnt += 1;
2166  return success();
2167  });
2168 
2169  if (cnt > maxCnt)
2170  return parser.emitError(parser.getNameLoc(),
2171  "parsed more value than expected.");
2172 
2173  if (failed(crdList)) {
2174  return parser.emitError(
2175  parser.getNameLoc(),
2176  "expecting SSA value or \"_\" for level coordinates");
2177  }
2178  assert(definedArgs.size() == definedSet.count());
2179  return success();
2180 }
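// Illustrative input: "(%v0, _, %v1)" collects the defined arguments %v0
// and %v1, sets bits 0 and 2 of definedSet, and leaves bit 1 (the "_"
// placeholder) cleared.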
2181 
2182 static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size,
2183  Block::BlockArgListType blocksArgs,
2184  I64BitSet definedSet) {
2185  if (definedSet.empty())
2186  return;
2187 
2188  for (unsigned i = 0; i < size; i++) {
2189  if (definedSet[i]) {
2190  p << blocksArgs.front();
2191  blocksArgs = blocksArgs.drop_front();
2192  } else {
2193  p << "_";
2194  }
2195  if (i != size - 1)
2196  p << ", ";
2197  }
2198  assert(blocksArgs.empty());
2199 }
2200 
2201 static ParseResult
2202 parseUsedCoordList(OpAsmParser &parser, OperationState &state,
2203  SmallVectorImpl<OpAsmParser::Argument> &coords) {
2204  // Parse "at(%crd0, _, ...)"
2205  I64BitSet crdUsedLvlSet;
2206  if (succeeded(parser.parseOptionalKeyword("at")) &&
2207  failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords)))
2208  return failure();
2209 
2210  // Always use IndexType for the coordinate.
2211  for (auto &coord : coords)
2212  coord.type = parser.getBuilder().getIndexType();
2213 
2214  // Set the CrdUsedLvl bitset.
2215  state.addAttribute("crdUsedLvls",
2216  parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
2217  return success();
2218 }
2219 
2220 static ParseResult
2221 parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
2222  SmallVectorImpl<OpAsmParser::Argument> &iterators,
2223  SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
2224  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
2225  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
2226 
2227  // Parse "%iters, ... in %spaces, ..."
2228  if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
2229  parser.parseOperandList(spaces))
2230  return failure();
2231 
2232  if (iterators.size() != spaces.size())
2233  return parser.emitError(
2234  parser.getNameLoc(),
2235  "mismatch in number of sparse iterators and sparse spaces");
2236 
2237  SmallVector<OpAsmParser::Argument> coords;
2238  if (failed(parseUsedCoordList(parser, state, coords)))
2239  return failure();
2240  size_t numCrds = coords.size();
2241 
2242  // Parse "iter_args(%arg = %init, ...)"
2243  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2244  if (hasIterArgs)
2245  if (parser.parseAssignmentList(blockArgs, initArgs))
2246  return failure();
2247 
2248  blockArgs.append(coords);
2249 
2250  SmallVector<Type> iterSpaceTps;
2251  // parse ": sparse_tensor.iter_space -> ret"
2252  if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
2253  return failure();
2254  if (iterSpaceTps.size() != spaces.size())
2255  return parser.emitError(parser.getNameLoc(),
2256  "mismatch in number of iteration space operands "
2257  "and iteration space types");
2258 
2259  for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
2260  IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
2261  if (!spaceTp)
2262  return parser.emitError(parser.getNameLoc(),
2263  "expected sparse_tensor.iter_space type for "
2264  "iteration space operands");
2265  it.type = spaceTp.getIteratorType();
2266  }
2267 
2268  if (hasIterArgs)
2269  if (parser.parseArrowTypeList(state.types))
2270  return failure();
2271 
2272  // Resolves input operands.
2273  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2274  state.operands))
2275  return failure();
2276 
2277  if (hasIterArgs) {
2278  // Strip off the trailing block args that are used for coordinates.
2279  MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
2280  if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2281  return parser.emitError(
2282  parser.getNameLoc(),
2283  "mismatch in number of iteration arguments and return values");
2284  }
2285 
2286  for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2287  it.type = tp;
2288  if (parser.resolveOperand(init, tp, state.operands))
2289  return failure();
2290  }
2291  }
2292  return success();
2293 }
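// Illustrative textual form accepted by this parser (a sketch; the #CSR
// encoding is an assumed example):
//   %r = sparse_tensor.iterate %it in %space at(%crd) iter_args(%a = %init)
//        : !sparse_tensor.iter_space<#CSR, lvls = 0> -> index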
2294 
2295 static ParseResult
2296 parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
2297  SmallVectorImpl<Value> &spacesVals,
2298  SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
2299 
2300  // Parse "(%spaces, ...)"
2301  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
2302  if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren))
2303  return failure();
2304 
2305  SmallVector<OpAsmParser::Argument> coords;
2306  if (failed(parseUsedCoordList(parser, state, coords)))
2307  return failure();
2308  size_t numCrds = coords.size();
2309 
2310  // Parse "iter_args(%arg = %init, ...)"
2311  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
2312  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
2313  if (hasIterArgs)
2314  if (parser.parseAssignmentList(blockArgs, initArgs))
2315  return failure();
2316  blockArgs.append(coords);
2317 
2318  SmallVector<Type> iterSpaceTps;
2319  // parse ": (sparse_tensor.iter_space, ...) -> ret"
2320  if (parser.parseColon() || parser.parseLParen() ||
2321  parser.parseTypeList(iterSpaceTps) || parser.parseRParen())
2322  return failure();
2323 
2324  if (iterSpaceTps.size() != spaces.size())
2325  return parser.emitError(parser.getNameLoc(),
2326  "mismatch in number of iteration space operands "
2327  "and iteration space types");
2328 
2329  if (hasIterArgs)
2330  if (parser.parseArrowTypeList(state.types))
2331  return failure();
2332 
2333  // Resolves input sparse iteration spaces.
2334  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
2335  spacesVals))
2336  return failure();
2337  state.operands.append(spacesVals);
2338 
2339  if (hasIterArgs) {
2340  // Strip off the trailing block args that are used for coordinates.
2341  MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
2342  if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2343  return parser.emitError(
2344  parser.getNameLoc(),
2345  "mismatch in number of iteration arguments and return values");
2346  }
2347 
2348  for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2349  it.type = tp;
2350  if (parser.resolveOperand(init, tp, state.operands))
2351  return failure();
2352  }
2353  }
2354  return success();
2355 }
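// Illustrative textual form accepted by this parser (a sketch; the #CSR
// encoding is an assumed example):
//   %r = sparse_tensor.coiterate (%sp0, %sp1) at(%crd) iter_args(%a = %init)
//        : (!sparse_tensor.iter_space<#CSR, lvls = 0>,
//           !sparse_tensor.iter_space<#CSR, lvls = 0>) -> index
// followed by one "case" region per combination of defined iterators, e.g.
//   case %it0, _ { ... }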
2356 
2357 LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2358  MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
2359  DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
2360  SmallVectorImpl<mlir::Type> &ret) {
2361 
2362  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2363  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
2364  ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
2365  adaptor.getHiLvl()));
2366  return success();
2367 }
2368 
2369 LogicalResult ExtractIterSpaceOp::verify() {
2370  if (getLoLvl() >= getHiLvl())
2371  return emitOpError("expected smaller level low than level high");
2372 
2373  TypedValue<IteratorType> pIter = getParentIter();
2374  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2375  return emitOpError(
2376  "parent iterator should be specified iff level lower bound equals 0");
2377  }
2378 
2379  if (pIter) {
2380  IterSpaceType spaceTp = getExtractedSpace().getType();
2381  if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2382  return emitOpError(
2383  "mismatch in parent iterator encoding and iteration space encoding.");
2384 
2385  if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2386  return emitOpError("parent iterator should be used to extract an "
2387  "iteration space from a consecutive level.");
2388  }
2389 
2390  return success();
2391 }
2392 
2393 LogicalResult ExtractValOp::verify() {
2394  auto stt = getSparseTensorType(getTensor());
2395  auto itTp = getIterator().getType();
2396 
2397  if (stt.getEncoding() != itTp.getEncoding())
2398  return emitOpError("mismatch in tensor encoding and iterator encoding.");
2399 
2400  if (stt.getLvlRank() != itTp.getHiLvl())
2401  return emitOpError("must use last-level iterator to extract values. ");
2402 
2403  return success();
2404 }
2405 
2406 struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
2407  using OpRewritePattern::OpRewritePattern;
2408 
2409  LogicalResult matchAndRewrite(IterateOp iterateOp,
2410  PatternRewriter &rewriter) const override {
2411  I64BitSet newUsedLvls(0);
2412  llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
2413  for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
2414  if (auto crd = iterateOp.getLvlCrd(i)) {
2415  if (crd->getUsers().empty())
2416  toRemove.set(crd->getArgNumber());
2417  else
2418  newUsedLvls.set(i);
2419  }
2420  }
2421 
2422  // All coordinates are used.
2423  if (toRemove.none())
2424  return failure();
2425 
2426  rewriter.startOpModification(iterateOp);
2427  iterateOp.setCrdUsedLvls(newUsedLvls);
2428  iterateOp.getBody()->eraseArguments(toRemove);
2429  rewriter.finalizeOpModification(iterateOp);
2430  return success();
2431  }
2432 };
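// Illustrative effect of the pattern above: if the coordinate %crd in
//   %r = sparse_tensor.iterate %it in %space at(%crd) iter_args(%a = %v)
// has no uses, the at(%crd) clause and its block argument are dropped.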
2433 
2434 void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
2435  mlir::MLIRContext *context) {
2436  results.add<RemoveUnusedLvlCrds>(context);
2437 }
2438 
2439 void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2440  Value iterSpace, ValueRange initArgs) {
2441  unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
2442  // All ones.
2443  I64BitSet set((1 << rank) - 1);
2444  return build(builder, odsState, iterSpace, initArgs, set);
2445 }
2446 
2447 void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2448  Value iterSpace, ValueRange initArgs,
2449  I64BitSet crdUsedLvls) {
2450  OpBuilder::InsertionGuard guard(builder);
2451 
2452  odsState.addOperands(iterSpace);
2453  odsState.addOperands(initArgs);
2454  odsState.getOrAddProperties<Properties>().crdUsedLvls =
2455  builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
2456  Region *bodyRegion = odsState.addRegion();
2457  odsState.addTypes(initArgs.getTypes());
2458  Block *bodyBlock = builder.createBlock(bodyRegion);
2459 
2460  // Starts with a list of user-provided loop arguments.
2461  for (Value v : initArgs)
2462  bodyBlock->addArgument(v.getType(), v.getLoc());
2463 
2464  // Followed by a list of used coordinates.
2465  for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
2466  bodyBlock->addArgument(builder.getIndexType(), odsState.location);
2467 
2468  // Ends with the sparse iterator.
2469  bodyBlock->addArgument(
2470  llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
2471  odsState.location);
2472 }
2473 
2474 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
2475  OpAsmParser::Argument iterator;
2476  OpAsmParser::UnresolvedOperand iterSpace;
2477 
2478  SmallVector<OpAsmParser::Argument> iters, iterArgs;
2479  if (parseSparseIterateLoop(parser, result, iters, iterArgs))
2480  return failure();
2481  if (iters.size() != 1)
2482  return parser.emitError(parser.getNameLoc(),
2483  "expected only one iterator/iteration space");
2484 
2485  iterArgs.append(iters);
2486  Region *body = result.addRegion();
2487  if (parser.parseRegion(*body, iterArgs))
2488  return failure();
2489 
2490  IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2491 
2492  // Parse the optional attribute list.
2493  if (parser.parseOptionalAttrDict(result.attributes))
2494  return failure();
2495 
2496  return success();
2497 }
2498 
2499 /// Prints the initialization list in the form of
2500 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
2501 /// where 'inner' values are assumed to be region arguments and 'outer' values
2502 /// are regular SSA values.
2503 static void printInitializationList(OpAsmPrinter &p,
2504  Block::BlockArgListType blocksArgs,
2505  ValueRange initializers,
2506  StringRef prefix = "") {
2507  assert(blocksArgs.size() == initializers.size() &&
2508  "expected same length of arguments and initializers");
2509  if (initializers.empty())
2510  return;
2511 
2512  p << prefix << '(';
2513  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
2514  p << std::get<0>(it) << " = " << std::get<1>(it);
2515  });
2516  p << ")";
2517 }
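// Illustrative output: with prefix " iter_args", region arguments
// (%a0, %a1), and initializers (%v0, %v1), this prints
// " iter_args(%a0 = %v0, %a1 = %v1)".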
2518 
2519 template <typename SparseLoopOp>
2520 static LogicalResult verifySparseLoopOp(SparseLoopOp op) {
2521  if (op.getInitArgs().size() != op.getNumResults()) {
2522  return op.emitOpError(
2523  "mismatch in number of loop-carried values and defined values");
2524  }
2525  if (op.getCrdUsedLvls().max() > op.getSpaceDim())
2526  return op.emitOpError("required out-of-bound coordinates");
2527 
2528  return success();
2529 }
2530 
2531 LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); }
2532 LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(*this); }
2533 
2534 void IterateOp::print(OpAsmPrinter &p) {
2535  p << " " << getIterator() << " in " << getIterSpace();
2536  if (!getCrdUsedLvls().empty()) {
2537  p << " at(";
2538  printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
2539  p << ")";
2540  }
2541  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
2542 
2543  p << " : " << getIterSpace().getType() << " ";
2544  if (!getInitArgs().empty())
2545  p.printArrowTypeList(getInitArgs().getTypes());
2546 
2547  p << " ";
2548  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2549  /*printBlockTerminators=*/!getInitArgs().empty());
2550 }
2551 
2552 LogicalResult IterateOp::verifyRegions() {
2553  if (getIterator().getType() != getIterSpace().getType().getIteratorType())
2554  return emitOpError("mismatch in iterator and iteration space type");
2555  if (getNumRegionIterArgs() != getNumResults())
2556  return emitOpError(
2557  "mismatch in number of basic block args and defined values");
2558 
2559  auto initArgs = getInitArgs();
2560  auto iterArgs = getRegionIterArgs();
2561  auto yieldVals = getYieldedValues();
2562  auto opResults = getResults();
2563  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2564  opResults.size()})) {
2565  return emitOpError() << "number mismatch between iter args and results.";
2566  }
2567 
2568  for (auto [i, init, iter, yield, ret] :
2569  llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2570  if (init.getType() != ret.getType())
2571  return emitOpError() << "types mismatch between " << i
2572  << "th iter operand and defined value";
2573  if (iter.getType() != ret.getType())
2574  return emitOpError() << "types mismatch between " << i
2575  << "th iter region arg and defined value";
2576  if (yield.getType() != ret.getType())
2577  return emitOpError() << "types mismatch between " << i
2578  << "th yield value and defined value";
2579  }
2580 
2581  return success();
2582 }
2583 
2584 /// OpInterfaces' methods implemented by IterateOp.
2585 SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
2586 
2587 MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
2588  return getInitArgsMutable();
2589 }
2590 
2591 Block::BlockArgListType IterateOp::getRegionIterArgs() {
2592  return getRegion().getArguments().take_front(getNumRegionIterArgs());
2593 }
2594 
2595 std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
2596  return cast<sparse_tensor::YieldOp>(
2597  getRegion().getBlocks().front().getTerminator())
2598  .getResultsMutable();
2599 }
2600 
2601 std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
2602 
2603 OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2604  return getInitArgs();
2605 }
2606 
2607 void IterateOp::getSuccessorRegions(RegionBranchPoint point,
2608  SmallVectorImpl<RegionSuccessor> &regions) {
2609  // Both the operation itself and the region may be branching into the body
2610  // or back into the operation itself.
2611  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2612  // It is possible for the loop not to enter the body.
2613  regions.push_back(RegionSuccessor(getResults()));
2614 }
2615 
2616 void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
2617  ValueRange iterSpaces, ValueRange initArgs,
2618  unsigned numCases) {
2619  unsigned rank =
2620  cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim();
2621  // All ones.
2622  I64BitSet set((1 << rank) - 1);
2623  // Generates all-zero case bits (they only serve as placeholders), which are
2624  // supposed to be overridden later. We need to preallocate all the regions
2625  // because an mlir::Region cannot be dynamically added after the operation
2626  // has been created.
2627  SmallVector<int64_t> caseBits(numCases, 0);
2628  ArrayAttr cases = builder.getI64ArrayAttr(caseBits);
2629  return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
2630  initArgs, set, cases,
2631  /*caseRegionsCount=*/numCases);
2632 }
2633 
2634 ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {
2635 
2636  SmallVector<Value> spaces;
2637  // The block argument list of each region is arranged in the order of
2638  // ([used coordinate list], [loop iteration args], [sparse iterator list]).
2639  SmallVector<OpAsmParser::Argument> blockArgs;
2640  if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs))
2641  return failure();
2642 
2643  result.addAttribute("operandSegmentSizes",
2644  parser.getBuilder().getDenseI32ArrayAttr(
2645  {static_cast<int32_t>(spaces.size()),
2646  static_cast<int32_t>(result.types.size())}));
2647 
2648  SmallVector<Attribute> cases;
2649  while (succeeded(parser.parseOptionalKeyword("case"))) {
2650  // Parse one region per case.
2651  I64BitSet definedItSet;
2652  SmallVector<OpAsmParser::Argument> definedIts;
2653  if (parseOptionalDefinedList(parser, result, definedItSet, definedIts,
2654  spaces.size(), OpAsmParser::Delimiter::None))
2655  return failure();
2656 
2657  cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet));
2658 
2659  for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) {
2660  // Resolve the iterator type based on the iteration space type.
2661  auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType());
2662  definedIts[i].type = spaceTp.getIteratorType();
2663  }
2664  definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
2665  Region *body = result.addRegion();
2666  if (parser.parseRegion(*body, definedIts))
2667  return failure();
2668 
2669  CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2670  }
2671 
2672  result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases));
2673 
2674  // Parse the optional attribute list.
2675  if (parser.parseOptionalAttrDict(result.attributes))
2676  return failure();
2677 
2678  return success();
2679 }
2680 
2682  p << " (";
2683  llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; });
2684  p << ")";
2685 
2686  if (!getCrdUsedLvls().empty()) {
2687  p << " at(";
2688  printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls());
2689  p << ")";
2690  }
2691 
2692  printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args");
2693 
2694  p << " : (" << getIterSpaces().getTypes() << ")";
2695  if (!getInitArgs().empty())
2696  p.printArrowTypeList(getInitArgs().getTypes());
2697 
2698  for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
2699  p.printNewline();
2700  p << "case ";
2701  printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx),
2702  getRegionDefinedSpace(idx));
2703  p << " ";
2704  p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false,
2705  /*printBlockTerminators=*/!getInitArgs().empty());
2706  }
2707 }
2708 
2709 ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
2710  return cast<sparse_tensor::YieldOp>(
2711  getRegion(regionIdx).getBlocks().front().getTerminator())
2712  .getResults();
2713 }
2714 
2715 LogicalResult CoIterateOp::verifyRegions() {
2716  for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
2717  if (getNumRegionIterArgs() != getNumResults())
2718  return emitOpError(
2719  "mismatch in number of basic block args and defined values");
2720 
2721  auto initArgs = getInitArgs();
2722  auto iterArgs = getRegionIterArgs(r);
2723  auto yieldVals = getYieldedValues(r);
2724  auto opResults = getResults();
2725  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2726  opResults.size()})) {
2727  return emitOpError()
2728  << "number mismatch between iter args and results on " << r
2729  << "th region";
2730  }
2731 
2732  for (auto [i, init, iter, yield, ret] :
2733  llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2734  if (init.getType() != ret.getType())
2735  return emitOpError()
2736  << "types mismatch between " << i
2737  << "th iter operand and defined value on " << r << "th region";
2738  if (iter.getType() != ret.getType())
2739  return emitOpError() << "types mismatch between " << i
2740  << "th iter region arg and defined value on " << r
2741  << "th region";
2742  if (yield.getType() != ret.getType())
2743  return emitOpError()
2744  << "types mismatch between " << i
2745  << "th yield value and defined value on " << r << "th region";
2746  }
2747  }
2748 
2749  auto cases = getRegionDefinedSpaces();
2750  llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
2751  if (set.size() != getNumRegions())
2752  return emitOpError("contains duplicated cases.");
2753 
2754  return success();
2755 }
2756 
2757 SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
2758  SmallVector<Region *> ret;
2759  I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
2760  for (Region &r : getCaseRegions())
2761  if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
2762  ret.push_back(&r);
2763 
2764  return ret;
2765 }
2766 
2767 //===----------------------------------------------------------------------===//
2768 // Sparse Tensor Dialect Setups.
2769 //===----------------------------------------------------------------------===//
2770 
2771 /// Materialize a single constant operation from a given attribute value with
2772 /// the desired resultant type.
2773 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2774  Attribute value, Type type,
2775  Location loc) {
2776  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2777  return op;
2778  return nullptr;
2779 }
2780 
2781 namespace {
2782 struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
2783  using OpAsmDialectInterface::OpAsmDialectInterface;
2784 
2785  AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
2786  if (isa<SparseTensorEncodingAttr>(attr)) {
2787  os << "sparse";
2788  return AliasResult::OverridableAlias;
2789  }
2790  return AliasResult::NoAlias;
2791  }
2792 };
2793 } // namespace
2794 
2795 void SparseTensorDialect::initialize() {
2796  addInterface<SparseTensorAsmDialectInterface>();
2797  addAttributes<
2798 #define GET_ATTRDEF_LIST
2799 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
2800  >();
2801  addTypes<
2802 #define GET_TYPEDEF_LIST
2803 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
2804  >();
2805  addOperations<
2806 #define GET_OP_LIST
2807 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2808  >();
2809  declarePromisedInterfaces<
2810  bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
2811  NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
2812  ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
2813 }
2814 
2815 #define GET_OP_CLASSES
2816 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2817 
2818 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static bool isPermutation(std::vector< PermutationTy > permutation)
Definition: IRAffine.cpp:71
static MLIRContext * getContext(OpFoldResult val)
bool isUnique(It begin, It end)
Definition: MeshOps.cpp:140
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
static LogicalResult verifyNumBlockArgs(T *op, Region &region, const char *regionName, TypeRange inputTypes, Type outputType)
static ParseResult parseOptionalStaticSlice(int64_t &result, AsmParser &parser)
static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc)
We normalized sparse tensor encoding attribute by always using ordered/unique LT such that "compresse...
static ParseResult parseUsedCoordList(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &coords)
static LogicalResult isMatchingWidth(Value mem, unsigned width)
static constexpr bool acceptBitWidth(unsigned bitWidth)
static mlir::ParseResult parseLevelRange(mlir::AsmParser &, mlir::sparse_tensor::Level &, mlir::sparse_tensor::Level &)
Parses a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static LogicalResult lvlIsInBounds(Level lvl, Value tensor)
static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size, Block::BlockArgListType blocksArgs, I64BitSet definedSet)
static constexpr FieldIndex kDataFieldStartingIdx
static constexpr Level kInvalidLevel
static LogicalResult verifySparseLoopOp(SparseLoopOp op)
static constexpr Level kInvalidFieldIndex
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, mlir::sparse_tensor::Level)
Prints a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind)
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op)
static SmallVector< Size > getSparseFieldShape(const SparseTensorEncodingAttr enc, std::optional< ArrayRef< int64_t >> dimShape)
static ParseResult parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &iterators, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static ParseResult parseOptionalDefinedList(OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, SmallVectorImpl< OpAsmParser::Argument > &definedArgs, unsigned maxCnt=std::numeric_limits< unsigned >::max(), OpAsmParser::Delimiter delimiter=OpAsmParser::Delimiter::Paren)
Parses a list of optional defined list in the form of "(%val0, _, %val1, ...)", where _ is used to an...
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, SparseTensorType stt, RankedTensorType valTp, TypeRange lvlTps)
static ParseResult parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< Value > &spacesVals, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static LogicalResult verifySparsifierGetterSetter(StorageSpecifierKind mdKind, std::optional< Level > lvl, TypedValue< StorageSpecifierType > md, Operation *op)
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, OpaqueProperties prop, RegionRange region, SmallVectorImpl< mlir::Type > &ret)
static bool isAllDense(uint64_t lvlRank, const LevelType *lvlTypes)
Definition: Storage.cpp:20
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
Base type for affine expression.
Definition: AffineExpr.h:68
void print(raw_ostream &os) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:415
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
Definition: AffineMap.cpp:369
unsigned getNumSymbols() const
Definition: AffineMap.cpp:398
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:648
The possible results of an alias query.
Definition: AliasAnalysis.h:26
@ NoAlias
The two locations do not alias at all.
Definition: AliasAnalysis.h:34
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseLBrace()=0
Parse a { token.
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseQuestion()=0
Parse a '?' token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
Definition: AsmPrinter.cpp:77
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printArrowTypeList(TypeRange &&types)
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
BlockArgListType getArguments()
Definition: Block.h:87
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:203
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:152
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:321
IndexType getIndexType()
Definition: Builders.cpp:95
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
OpAsmDialectInterface(Dialect *dialect)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
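A minimal sketch combining the OpBuilder pieces above: the RAII guard restores the caller's insertion point once the new block has been created.
  static mlir::Block *appendBodyBlock(mlir::OpBuilder &builder,
                                      mlir::Region &region,
                                      mlir::Location loc) {
    mlir::OpBuilder::InsertionGuard guard(builder);
    // New block with a single index argument; the insertion point is
    // moved to the end of that block until the guard unwinds.
    return builder.createBlock(&region, region.end(),
                               {builder.getIndexType()}, {loc});
  }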
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Simple wrapper around a void* in order to express generically how to pass in op properties through APIs.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
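A minimal sketch of the Operation accessors above, using emitError for the failure path (the invariant checked is arbitrary):
  static mlir::LogicalResult checkHasOperands(mlir::Operation *op) {
    if (op->getNumOperands() == 0)
      return op->emitError("expected at least one operand");
    for (mlir::Value result : op->getResults())
      (void)result.getType(); // each result is a typed Value
    return mlir::success();
  }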
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.
Definition: PatternMatch.h:791
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen.
Definition: PatternMatch.h:620
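The two hooks above bracket an in-place update inside a rewrite pattern; a minimal sketch (the "visited" attribute name is invented for illustration):
  static void markVisited(mlir::PatternRewriter &rewriter,
                          mlir::Operation *op) {
    rewriter.startOpModification(op);
    op->setAttr("visited", rewriter.getUnitAttr());
    rewriter.finalizeOpModification(op); // commits the in-place change
  }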
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:64
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:66
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
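A minimal sketch of the Value accessors above; note that getDefiningOp returns null for block arguments:
  static void describeValue(mlir::Value value) {
    mlir::Type type = value.getType();
    mlir::Location loc = value.getLoc();
    (void)type; (void)loc;
    if (mlir::Operation *def = value.getDefiningOp())
      (void)def; // an op result; otherwise the value is a block argument
  }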
A simple wrapper to encode a bitset of (at most 64) levels, currently used by the sparse_tensor.iterate operation.
Definition: SparseTensor.h:64
iterator_range< const_set_bits_iterator > bits() const
Definition: SparseTensor.h:75
I64BitSet & set(unsigned i)
Definition: SparseTensor.h:83
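A minimal sketch of I64BitSet, assuming a default-constructed set is empty; set() returns the bitset itself, so calls chain:
  #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

  static void visitLevels() {
    mlir::sparse_tensor::I64BitSet mask;
    mask.set(0).set(2).set(5);
    for (unsigned l : mask.bits()) // visits 0, 2, 5 in order
      (void)l;
  }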
A wrapper around RankedTensorType, which has three goals:
unsigned getCrdWidth() const
Returns the coordinate-overhead bitwidth, defaulting to zero.
SmallVector< Size > getBatchLvlShape() const
Returns the batched level-shape.
ArrayRef< LevelType > getLvlTypes() const
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
ArrayRef< Size > getDimShape() const
Returns the dimension-shape.
bool isAllOrdered() const
Returns true for tensors where every level is ordered.
SmallVector< Size > getLvlShape() const
Returns the level-shape.
bool isCOOType(Level startLvl=0, bool isUnique=true) const
Returns true iff this sparse tensor type has a trailing COO region starting at the given level.
Dimension getDimRank() const
Returns the dimension-rank.
bool isAllDense() const
Returns true for tensors where every level is dense.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
bool hasSameDimToLvl(const SparseTensorType &other) const
Returns true iff the two types have the same mapping.
bool hasStaticDimShape() const
Returns true if no dimension has dynamic size.
Level getLvlRank() const
Returns the level-rank.
unsigned getPosWidth() const
Returns the position-overhead bitwidth, defaulting to zero.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least two levels.
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
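A minimal sketch that uses the free function getSparseTensorType (listed further below) together with the SparseTensorType queries above:
  #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

  static void inspectSparseValue(mlir::Value val) {
    using namespace mlir::sparse_tensor;
    SparseTensorType stt = getSparseTensorType(val);
    if (!stt.hasEncoding())
      return;                             // dense tensor; nothing sparse to query
    Level lvlRank = stt.getLvlRank();     // number of storage levels
    mlir::Type posTp = stt.getPosType();  // position overhead type
    mlir::Type crdTp = stt.getCrdType();  // coordinate overhead type
    bool allOrdered = stt.isAllOrdered();
    (void)lvlRank; (void)posTp; (void)crdTp; (void)allOrdered;
  }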
Provides methods to access fields of a sparse tensor with the given encoding.
unsigned getNumDataFields() const
Gets the total number of data fields (coordinate arrays, position arrays, and a value array) for the given sparse tensor encoding.
unsigned getNumFields() const
Gets the total number of fields for the given sparse tensor encoding.
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with the corresponding field index, field kind, level, and level-type.
std::pair< FieldIndex, unsigned > getFieldIndexAndStride(SparseTensorFieldKind kind, std::optional< Level > lvl) const
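A hedged sketch of foreachField, assuming StorageLayout is constructible from the encoding attribute and that returning false from the callback stops the iteration; it counts the position-array fields:
  #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"

  static unsigned
  countPosFields(mlir::sparse_tensor::SparseTensorEncodingAttr enc) {
    using namespace mlir::sparse_tensor;
    unsigned numPos = 0;
    StorageLayout layout(enc);
    layout.foreachField([&](FieldIndex fidx, SparseTensorFieldKind kind,
                            Level lvl, LevelType lt) {
      if (kind == SparseTensorFieldKind::PosMemRef)
        ++numPos;
      return true; // keep iterating
    });
    return numPos;
  }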
Parses the Sparse Tensor Encoding Attribute (STEA).
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition: Barvinok.cpp:63
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
bool isWithCrdLT(LevelType lt)
Definition: Enums.h:431
bool isWithPosLT(LevelType lt)
Definition: Enums.h:432
bool isOrderedLT(LevelType lt)
Definition: Enums.h:425
std::string toMLIRString(LevelType lt)
Definition: Enums.h:447
Dimension toDim(SparseTensorEncodingAttr enc, Level l)
Convenience method to translate the given level to the corresponding dimension.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
bool isSingletonLT(LevelType lt)
Definition: Enums.h:421
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
Definition: SparseTensor.h:39
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:42
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
uint64_t getN(LevelType lt)
Definition: Enums.h:442
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic (for shapes).
Definition: SparseTensor.h:46
llvm::hash_code hash_value(LevelType lt)
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context)
Given the dimToLvl map, infers the lvlToDim map, or returns an empty AffineMap when inference fails.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
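A minimal sketch: the returned attribute is null for types without a sparse encoding, so the getter doubles as a sparsity test.
  static bool isSparseTensorType(mlir::Type type) {
    return static_cast<bool>(
        mlir::sparse_tensor::getSparseTensorEncoding(type));
  }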
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Definition: SparseTensor.h:168
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
bool isBlockSparsity(AffineMap dimToLvl)
Given the dimToLvl map, returns if it's block sparsity.
bool isDenseLT(LevelType lt)
Definition: Enums.h:413
uint64_t getM(LevelType lt)
Definition: Enums.h:443
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff the MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience method to obtain a SparseTensorType from a Value.
SparseTensorFieldKind
The sparse tensor storage scheme.
bool isBatchLT(LevelType lt)
Definition: Enums.h:414
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context)
Returns the lvlToDim map for the given dimToLvl map specific to the block sparse cases.
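A hedged sketch combining the three block-sparsity helpers above, e.g. for a 2x2 BSR map (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2):
  static mlir::AffineMap invertIfBlocked(mlir::AffineMap dimToLvl,
                                         mlir::MLIRContext *ctx) {
    using namespace mlir::sparse_tensor;
    if (!isBlockSparsity(dimToLvl))
      return mlir::AffineMap(); // not a block-sparse mapping
    llvm::SmallVector<unsigned> sizes = getBlockSize(dimToLvl); // {2, 2} above
    (void)sizes;
    return inverseBlockSparsity(dimToLvl, ctx);
  }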
std::optional< LevelType > buildLevelType(LevelFormat lf, const std::vector< LevelPropNonDefault > &properties, uint64_t n=0, uint64_t m=0)
Definition: Enums.h:402
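A hedged sketch of buildLevelType, assuming the property enumerator is spelled LevelPropNonDefault::Nonunique as in Enums.h: request a compressed level without the unique property.
  static std::optional<mlir::sparse_tensor::LevelType>
  makeCompressedNonunique() {
    using namespace mlir::sparse_tensor;
    return buildLevelType(LevelFormat::Compressed,
                          {LevelPropNonDefault::Nonunique});
  }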
bool isNOutOfMLT(LevelType lt)
Definition: Enums.h:424
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
Definition: AffineMap.cpp:791
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:70
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:641
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:617
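A minimal sketch of these free functions: build d0 * 4, double it, and let simplifyAffineExpr fold the sum into d0 * 8.
  #include "mlir/IR/AffineExpr.h"

  static mlir::AffineExpr buildAndSimplify(mlir::MLIRContext *ctx) {
    mlir::AffineExpr d0 = mlir::getAffineDimExpr(0, ctx);
    mlir::AffineExpr c4 = mlir::getAffineConstantExpr(4, ctx);
    mlir::AffineExpr mul =
        mlir::getAffineBinaryOpExpr(mlir::AffineExprKind::Mul, d0, c4);
    return mlir::simplifyAffineExpr(mul + mul, /*numDims=*/1,
                                    /*numSymbols=*/0);
  }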
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on error.
Definition: Verifier.cpp:426
LogicalResult matchAndRewrite(IterateOp iterateOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
Definition: PatternMatch.h:362
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.
A simple structure that encodes a range of levels in a sparse tensor that forms a COO segment.
Definition: SparseTensor.h:50
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.
Definition: Enums.h:326
LevelType stripStorageIrrelevantProperties() const
Definition: Enums.h:299