//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "Detail/DimLvlMapParser.h"

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Additional convenience methods.
//===----------------------------------------------------------------------===//

static constexpr bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}

//===----------------------------------------------------------------------===//
// StorageLayout
//===----------------------------------------------------------------------===//

static constexpr Level kInvalidLevel = -1u;
static constexpr Level kInvalidFieldIndex = -1u;
static constexpr FieldIndex kDataFieldStartingIdx = 0;

void StorageLayout::foreachField(
    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
                            DimLevelType)>
        callback) const {
#define RETURN_ON_FALSE(fidx, kind, lvl, dlt)                                  \
  if (!(callback(fidx, kind, lvl, dlt)))                                       \
    return;

  const auto lvlTypes = enc.getLvlTypes();
  const Level lvlRank = enc.getLvlRank();
  const Level cooStart = getCOOStart(enc);
  const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
  FieldIndex fieldIdx = kDataFieldStartingIdx;
  // Per-level storage.
  for (Level l = 0; l < end; l++) {
    const auto dlt = lvlTypes[l];
    if (isDLTWithPos(dlt)) {
      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt);
    }
    if (isDLTWithCrd(dlt)) {
      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt);
    }
  }
  // The values array.
  RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
                  DimLevelType::Undef);

  // Put metadata at the end.
  RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::StorageSpec,
                  kInvalidLevel, DimLevelType::Undef);

#undef RETURN_ON_FALSE
}
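
// For example, under a 2-D CSR-style encoding (dense, compressed), the walk
// above visits, in order: the positions memref and coordinates memref of
// level 1, then the values memref, and finally the storage specifier:
//
//   fieldIdx  fieldKind    lvl
//   --------  -----------  ---
//      0      PosMemRef     1
//      1      CrdMemRef     1
//      2      ValMemRef     -
//      3      StorageSpec   -
//
// (An illustrative walk-through of the code above, not an additional
// contract.)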

void sparse_tensor::foreachFieldAndTypeInSparseTensor(
    SparseTensorType stt,
    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
                            DimLevelType)>
        callback) {
  assert(stt.hasEncoding());
  // Construct the basic types.
  const Type crdType = stt.getCrdType();
  const Type posType = stt.getPosType();
  const Type eltType = stt.getElementType();

  const Type specType = StorageSpecifierType::get(stt.getEncoding());
  // memref<? x pos>  positions
  const Type posMemType = MemRefType::get({ShapedType::kDynamic}, posType);
  // memref<? x crd>  coordinates
  const Type crdMemType = MemRefType::get({ShapedType::kDynamic}, crdType);
  // memref<? x eltType> values
  const Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);

  StorageLayout(stt.getEncoding())
      .foreachField([specType, posMemType, crdMemType, valMemType,
                     callback](FieldIndex fieldIdx,
                               SparseTensorFieldKind fieldKind, Level lvl,
                               DimLevelType dlt) -> bool {
        switch (fieldKind) {
        case SparseTensorFieldKind::StorageSpec:
          return callback(specType, fieldIdx, fieldKind, lvl, dlt);
        case SparseTensorFieldKind::PosMemRef:
          return callback(posMemType, fieldIdx, fieldKind, lvl, dlt);
        case SparseTensorFieldKind::CrdMemRef:
          return callback(crdMemType, fieldIdx, fieldKind, lvl, dlt);
        case SparseTensorFieldKind::ValMemRef:
          return callback(valMemType, fieldIdx, fieldKind, lvl, dlt);
        };
        llvm_unreachable("unrecognized field kind");
      });
}

unsigned StorageLayout::getNumFields() const {
  unsigned numFields = 0;
  foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
                            DimLevelType) -> bool {
    numFields++;
    return true;
  });
  return numFields;
}

unsigned StorageLayout::getNumDataFields() const {
  unsigned numFields = 0; // one value memref
  foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
                            DimLevelType) -> bool {
    if (fidx >= kDataFieldStartingIdx)
      numFields++;
    return true;
  });
  numFields -= 1; // the last field is the StorageSpecifier
  assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
  return numFields;
}

std::pair<FieldIndex, unsigned>
StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
                                      std::optional<Level> lvl) const {
  FieldIndex fieldIdx = kInvalidFieldIndex;
  unsigned stride = 1;
  if (kind == SparseTensorFieldKind::CrdMemRef) {
    assert(lvl.has_value());
    const Level cooStart = getCOOStart(enc);
    const Level lvlRank = enc.getLvlRank();
    if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
      lvl = cooStart;
      stride = lvlRank - cooStart;
    }
  }
  foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
                                      SparseTensorFieldKind fKind, Level fLvl,
                                      DimLevelType dlt) -> bool {
    if ((lvl && fLvl == lvl.value() && kind == fKind) ||
        (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
      fieldIdx = fIdx;
      // Return false to break the iteration.
      return false;
    }
    return true;
  });
  assert(fieldIdx != kInvalidFieldIndex);
  return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
}
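
// For example, given a 2-D COO encoding (compressed_nu, singleton), the two
// coordinate levels share a single AOS coordinates memref, so querying
// CrdMemRef for either level yields the same field index with stride 2
// (a sketch of the behavior implemented above).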

//===----------------------------------------------------------------------===//
// TensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//

std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
  return isDynamic(v) ? std::nullopt
                      : std::make_optional(static_cast<uint64_t>(v));
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
  return getStatic(getOffset());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
  return getStatic(getStride());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
  return getStatic(getSize());
}

bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
  return isDynamic(getOffset()) && isDynamic(getStride()) &&
         isDynamic(getSize());
}

std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
  return isDynamic(v) ? "?" : std::to_string(v);
}

void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
  assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
  os << '(';
  os << getStaticString(getOffset());
  os << ", ";
  os << getStaticString(getSize());
  os << ", ";
  os << getStaticString(getStride());
  os << ')';
}

void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
  print(printer.getStream());
}

static ParseResult parseOptionalStaticSlice(int64_t &result,
                                            AsmParser &parser) {
  auto parseResult = parser.parseOptionalInteger(result);
  if (parseResult.has_value()) {
    if (parseResult.value().succeeded() && result < 0) {
      parser.emitError(
          parser.getCurrentLocation(),
          "expect non-negative value or ? for slice offset/size/stride");
      return failure();
    }
    return parseResult.value();
  }

  // Else, expect a '?' representing a dynamic slice.
  result = SparseTensorDimSliceAttr::kDynamic;
  return parser.parseQuestion();
}

Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
  int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;

  if (failed(parser.parseLParen()) ||
      failed(parseOptionalStaticSlice(offset, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(size, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(stride, parser)) ||
      failed(parser.parseRParen()))
    return {};

  return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
                                                     offset, size, stride);
}
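
// For illustration, this parser accepts textual slices such as
//   (1, 4, 2)   // offset 1, size 4, stride 2
//   (?, ?, ?)   // fully dynamic slice
// mirroring the `print` format defined above.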

LogicalResult
SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 int64_t offset, int64_t size, int64_t stride) {
  if (!isDynamic(offset) && offset < 0)
    return emitError() << "expect non-negative value or ? for slice offset";
  if (!isDynamic(size) && size <= 0)
    return emitError() << "expect positive value or ? for slice size";
  if (!isDynamic(stride) && stride <= 0)
    return emitError() << "expect positive value or ? for slice stride";
  return success();
}

Type mlir::sparse_tensor::detail::getIntegerOrIndexType(MLIRContext *ctx,
                                                        unsigned bitwidth) {
  if (bitwidth)
    return IntegerType::get(ctx, bitwidth);
  return IndexType::get(ctx);
}

Type SparseTensorEncodingAttr::getPosType() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return detail::getIntegerOrIndexType(getContext(), getPosWidth());
}

Type SparseTensorEncodingAttr::getCrdType() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return detail::getIntegerOrIndexType(getContext(), getCrdWidth());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
                                       getPosWidth(), getCrdWidth());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
  return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
  return withDimToLvl(AffineMap());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                        unsigned crdWidth) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
                                       getDimToLvl(), posWidth, crdWidth);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
  return withBitWidths(0, 0);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
                                       getDimToLvl(), getPosWidth(),
                                       getCrdWidth(), dimSlices);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
  return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
}

bool SparseTensorEncodingAttr::isAllDense() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isDenseDLT);
}

bool SparseTensorEncodingAttr::isAllOrdered() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedDLT);
}

bool SparseTensorEncodingAttr::isIdentity() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
}

bool SparseTensorEncodingAttr::isPermutation() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
}

Dimension SparseTensorEncodingAttr::getDimRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  const auto dimToLvl = getDimToLvl();
  return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
}

Level SparseTensorEncodingAttr::getLvlRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return getLvlTypes().size();
}

DimLevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
  if (!getImpl())
    return DimLevelType::Dense;
  assert(l < getLvlRank() && "Level is out of bounds");
  return getLvlTypes()[l];
}

bool SparseTensorEncodingAttr::isSlice() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return !getDimSlices().empty();
}

SparseTensorDimSliceAttr
SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
  assert(isSlice() && "Is not a slice");
  const auto dimSlices = getDimSlices();
  assert(dim < dimSlices.size() && "Dimension is out of bounds");
  return dimSlices[dim];
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
  return getDimSlice(dim).getStaticOffset();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceSize(Dimension dim) const {
  return getDimSlice(dim).getStaticSize();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
  return getDimSlice(dim).getStaticStride();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
  // FIXME: `toOrigDim` is deprecated.
  return getStaticDimSliceOffset(toOrigDim(*this, lvl));
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceSize(Level lvl) const {
  // FIXME: `toOrigDim` is deprecated.
  return getStaticDimSliceSize(toOrigDim(*this, lvl));
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
  // FIXME: `toOrigDim` is deprecated.
  return getStaticDimSliceStride(toOrigDim(*this, lvl));
}

const static DimLevelType validDLTs[] = {DimLevelType::Dense,
                                         DimLevelType::TwoOutOfFour,
                                         DimLevelType::Compressed,
                                         DimLevelType::CompressedNu,
                                         DimLevelType::CompressedNo,
                                         DimLevelType::CompressedNuNo,
                                         DimLevelType::Singleton,
                                         DimLevelType::SingletonNu,
                                         DimLevelType::SingletonNo,
                                         DimLevelType::SingletonNuNo,
                                         DimLevelType::CompressedWithHi,
                                         DimLevelType::CompressedWithHiNu,
                                         DimLevelType::CompressedWithHiNo,
                                         DimLevelType::CompressedWithHiNuNo};

static std::optional<DimLevelType> parseDLT(StringRef str) {
  for (DimLevelType dlt : validDLTs)
    if (str == toMLIRString(dlt))
      return dlt;
  return std::nullopt;
}

Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
#define RETURN_ON_FAIL(stmt)                                                   \
  if (failed(stmt)) {                                                          \
    return {};                                                                 \
  }
#define ERROR_IF(COND, MSG)                                                    \
  if (COND) {                                                                  \
    parser.emitError(parser.getNameLoc(), MSG);                                \
    return {};                                                                 \
  }

  RETURN_ON_FAIL(parser.parseLess())
  RETURN_ON_FAIL(parser.parseLBrace())

  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<DimLevelType> lvlTypes;
  SmallVector<SparseTensorDimSliceAttr> dimSlices;
  AffineMap dimToLvl = {};
  unsigned posWidth = 0;
  unsigned crdWidth = 0;
  StringRef attrName;
  SmallVector<StringRef, 6> keys = {"lvlTypes", "dimToLvl", "posWidth",
                                    "crdWidth", "dimSlices", "map"};
  while (succeeded(parser.parseOptionalKeyword(&attrName))) {
    // Detect admissible keyword.
    auto *it = find(keys, attrName);
    if (it == keys.end()) {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
      return {};
    }
    unsigned keyWordIndex = it - keys.begin();
    // Consume the `=` after the key.
    RETURN_ON_FAIL(parser.parseEqual())
    // Dispatch on the keyword.
    switch (keyWordIndex) {
    case 0: { // lvlTypes
      Attribute attr;
      RETURN_ON_FAIL(parser.parseAttribute(attr));
      auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr);
      ERROR_IF(!arrayAttr, "expected an array for lvlTypes")
      for (auto i : arrayAttr) {
        auto strAttr = llvm::dyn_cast<StringAttr>(i);
        ERROR_IF(!strAttr, "expected a string value in lvlTypes")
        auto strVal = strAttr.getValue();
        if (auto optDLT = parseDLT(strVal)) {
          lvlTypes.push_back(optDLT.value());
        } else {
          parser.emitError(parser.getNameLoc(), "unexpected level-type: ")
              << strVal;
          return {};
        }
      }
      break;
    }
    case 1: { // dimToLvl
      Attribute attr;
      RETURN_ON_FAIL(parser.parseAttribute(attr))
      auto affineAttr = llvm::dyn_cast<AffineMapAttr>(attr);
      ERROR_IF(!affineAttr, "expected an affine map for dimToLvl")
      dimToLvl = affineAttr.getValue();
      break;
    }
    case 2: { // posWidth
      Attribute attr;
      RETURN_ON_FAIL(parser.parseAttribute(attr))
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      ERROR_IF(!intAttr, "expected an integral position bitwidth")
      posWidth = intAttr.getInt();
      break;
    }
    case 3: { // crdWidth
      Attribute attr;
      RETURN_ON_FAIL(parser.parseAttribute(attr))
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      ERROR_IF(!intAttr, "expected an integral index bitwidth")
      crdWidth = intAttr.getInt();
      break;
    }
    case 4: { // dimSlices
      RETURN_ON_FAIL(parser.parseLSquare())
      // Dispatch to DimSliceAttr directly so as to skip the mnemonic.
      bool finished = false;
      while (auto attr = SparseTensorDimSliceAttr::parse(parser, nullptr)) {
        auto sliceAttr = llvm::cast<SparseTensorDimSliceAttr>(attr);
        dimSlices.push_back(sliceAttr);
        if (parser.parseOptionalComma().failed()) {
          finished = true;
          break;
        }
      }
      // Something went wrong while parsing the slices.
      if (!finished)
        return {};
      RETURN_ON_FAIL(parser.parseRSquare())
      break;
    }
    case 5: { // map (new STEA surface syntax)
      ir_detail::DimLvlMapParser cParser(parser);
      auto res = cParser.parseDimLvlMap();
      RETURN_ON_FAIL(res);
      // TODO: use DimLvlMap directly as storage representation, rather
      // than converting things over.
      const auto &dlm = *res;

      ERROR_IF(!lvlTypes.empty(), "Cannot mix `lvlTypes` with `map`")
      const Level lvlRank = dlm.getLvlRank();
      for (Level lvl = 0; lvl < lvlRank; lvl++)
        lvlTypes.push_back(dlm.getLvlType(lvl));

      ERROR_IF(!dimSlices.empty(), "Cannot mix `dimSlices` with `map`")
      const Dimension dimRank = dlm.getDimRank();
      for (Dimension dim = 0; dim < dimRank; dim++)
        dimSlices.push_back(dlm.getDimSlice(dim));
      // NOTE: the old syntax requires an all-or-nothing approach to
      // `dimSlices`; therefore, if any slice actually exists then we need
      // to convert null-DSA into default/nop DSA.
      const auto isDefined = [](SparseTensorDimSliceAttr slice) {
        return static_cast<bool>(slice.getImpl());
      };
      if (llvm::any_of(dimSlices, isDefined)) {
        const auto defaultSlice =
            SparseTensorDimSliceAttr::get(parser.getContext());
        for (Dimension dim = 0; dim < dimRank; dim++)
          if (!isDefined(dimSlices[dim]))
            dimSlices[dim] = defaultSlice;
      } else {
        dimSlices.clear();
      }

      ERROR_IF(dimToLvl, "Cannot mix `dimToLvl` with `map`")
      dimToLvl = dlm.getDimToLvlMap(parser.getContext());
      break;
    }
    } // switch
    // Only the last item can omit the comma.
    if (parser.parseOptionalComma().failed())
      break;
  }

  RETURN_ON_FAIL(parser.parseRBrace())
  RETURN_ON_FAIL(parser.parseGreater())
#undef ERROR_IF
#undef RETURN_ON_FAIL

  // Construct struct-like storage for attribute.
  return parser.getChecked<SparseTensorEncodingAttr>(
      parser.getContext(), lvlTypes, dimToLvl, posWidth, crdWidth, dimSlices);
}
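
// For illustration, this parser accepts the dictionary-style syntax, e.g.
//
//   #sparse_tensor.encoding<{
//     lvlTypes = [ "dense", "compressed" ],
//     posWidth = 32,
//     crdWidth = 32
//   }>
//
// as well as the newer `map` surface syntax handled by DimLvlMapParser,
// e.g. `map = (d0, d1) -> (d0 : dense, d1 : compressed)`; see the STEA
// documentation for the authoritative grammar.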

void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
  // Print the struct-like storage in dictionary fashion.
  printer << "<{ lvlTypes = [ ";
  llvm::interleaveComma(getLvlTypes(), printer, [&](DimLevelType dlt) {
    printer << "\"" << toMLIRString(dlt) << "\"";
  });
  printer << " ]";
  // Print remaining members only for non-default values.
  if (!isIdentity())
    printer << ", dimToLvl = affine_map<" << getDimToLvl() << ">";
  if (getPosWidth())
    printer << ", posWidth = " << getPosWidth();
  if (getCrdWidth())
    printer << ", crdWidth = " << getCrdWidth();
  if (!getDimSlices().empty()) {
    printer << ", dimSlices = [ ";
    llvm::interleaveComma(getDimSlices(), printer,
                          [&](SparseTensorDimSliceAttr attr) {
                            // Call SparseTensorDimSliceAttr::print directly
                            // to skip the mnemonic.
                            attr.print(printer);
                          });
    printer << " ]";
  }

  printer << " }>";
}

LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError,
    ArrayRef<DimLevelType> lvlTypes, AffineMap dimToLvl, unsigned posWidth,
    unsigned crdWidth, ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
  if (!acceptBitWidth(posWidth))
    return emitError() << "unexpected position bitwidth: " << posWidth;
  if (!acceptBitWidth(crdWidth))
    return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
  // Before we can check that the level-rank is consistent/coherent
  // across all fields, we need to define it. The source-of-truth for
  // the `getLvlRank` method is the length of the level-types array,
  // since it must always be provided and have full rank; therefore we
  // use that same source-of-truth here.
  const Level lvlRank = lvlTypes.size();
  if (lvlRank == 0)
    return emitError() << "expected a non-empty array for lvlTypes";
  // We save `dimRank` here because we'll also need it to verify `dimSlices`.
  const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
  if (dimToLvl) {
    if (dimToLvl.getNumResults() != lvlRank)
      return emitError()
             << "level-rank mismatch between dimToLvl and lvlTypes: "
             << dimToLvl.getNumResults() << " != " << lvlRank;
    // TODO: The following is attempting to match the old error-conditions
    // from prior to merging dimOrdering and higherOrdering into dimToLvl.
    // That is, we currently require `dimToLvl` to be either a permutation
    // (as when higherOrdering is the identity) or expansive (as per the
    // constraints on higherOrdering). However, those constraints do
    // not match the intended semantics of `dimToLvl`. As we improve the
    // compiler to actually handle non-permutations, we need to update these
    // checks to match what is actually supported. In particular, this is
    // where we'll have to check that when `lvlToDim` is provided then it
    // is indeed an inverse of `dimToLvl`, and when it isn't provided then
    // it can be automatically inferred.
    if (dimRank == lvlRank && !dimToLvl.isPermutation())
      return emitError() << "expected a permutation affine map for dimToLvl";
    if (dimRank > lvlRank)
      return emitError() << "unexpected dimToLvl mapping from " << dimRank
                         << " to " << lvlRank;
  }
  if (!dimSlices.empty()) {
    if (dimSlices.size() != dimRank)
      return emitError()
             << "dimension-rank mismatch between dimSlices and dimToLvl: "
             << dimSlices.size() << " != " << dimRank;
    // Compiler support for `dimSlices` currently requires that the two
    // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
    if (dimRank != lvlRank)
      return emitError()
             << "dimSlices expected dimension-rank to match level-rank: "
             << dimRank << " != " << lvlRank;
  }
  return success();
}

#define RETURN_FAILURE_IF_FAILED(X)                                            \
  if (failed(X)) {                                                             \
    return failure();                                                          \
  }

LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<DynSize> dimShape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity. In particular, this ensures that the
  // level-rank is coherent across all the fields.
  RETURN_FAILURE_IF_FAILED(verify(emitError, getLvlTypes(), getDimToLvl(),
                                  getPosWidth(), getCrdWidth(),
                                  getDimSlices()))
  // Check integrity with tensor type specifics. In particular, we
  // need only check that the dimension-rank of the tensor agrees with
  // the dimension-rank of the encoding.
  const Dimension dimRank = dimShape.size();
  if (dimRank == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimRank() != dimRank)
    return emitError()
           << "dimension-rank mismatch between encoding and tensor shape: "
           << getDimRank() << " != " << dimRank;
  return success();
}

//===----------------------------------------------------------------------===//
// Convenience Methods.
//===----------------------------------------------------------------------===//

SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
  if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
    return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
  if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
    return mdtp.getEncoding();
  return nullptr;
}

bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
                                    Level startLvl, bool isUnique) {
  if (!enc ||
      !(enc.isCompressedLvl(startLvl) || enc.isCompressedWithHiLvl(startLvl)))
    return false;
  const Level lvlRank = enc.getLvlRank();
  for (Level l = startLvl + 1; l < lvlRank; ++l)
    if (!enc.isSingletonLvl(l))
      return false;
  // If isUnique is true, then make sure that the last level is unique:
  // when lvlRank == 1 the only compressed level must be unique, and when
  // lvlRank > 1 the last singleton level must be unique.
  return !isUnique || enc.isUniqueLvl(lvlRank - 1);
}

bool mlir::sparse_tensor::isUniqueCOOType(Type tp) {
  return isCOOType(getSparseTensorEncoding(tp), 0, /*isUnique=*/true);
}

Level mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
  // We only consider a COO region with at least two levels for the purpose
  // of AOS storage optimization.
  const Level lvlRank = enc.getLvlRank();
  if (lvlRank > 1)
    for (Level l = 0; l < lvlRank - 1; l++)
      if (isCOOType(enc, l, /*isUnique=*/false))
        return l;
  return lvlRank;
}
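
// For example, for level-types [dense, compressed_nu, singleton] the trailing
// COO region starts at level 1, so getCOOStart returns 1; for [dense,
// compressed] there is no two-level COO region and it returns the level-rank.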

// Helpers to setup a COO type.
RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
                                                           AffineMap lvlPerm,
                                                           bool ordered) {
  const SparseTensorType src(rtt);
  // TODO: This assertion is to match the behavior from before we merged
  // dimOrdering and higherOrdering into dimToLvl. However, there's no
  // in-principle reason to require this. (wrengr has a commit in the
  // wings to fix this.)
  assert(src.isPermutation());
  const Level lvlRank = src.getLvlRank();
  SmallVector<DimLevelType> lvlTypes;
  lvlTypes.reserve(lvlRank);

  // Start with an unordered and non-unique compressed level.
  // If this is also the last level, then it is unique.
  lvlTypes.push_back(
      *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
  if (lvlRank > 1) {
    // TODO: it is actually ordered at the level for ordered input.
    // Followed by n-2 unordered and non-unique singleton levels.
    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
                *buildLevelType(LevelFormat::Singleton, ordered, false));
    // End with a unique singleton level (unless lvlRank is 1).
    lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
  }

  // TODO: Maybe pick the bitwidth based on input/output tensors (probably the
  // largest one among them) in the original operation instead of using the
  // default value.
  unsigned posWidth = src.getPosWidth();
  unsigned crdWidth = src.getCrdWidth();
  auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm,
                                           posWidth, crdWidth);
  return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc);
}
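
// For example, with `ordered` set, a rank-3 input yields the level-types
//   [ compressed_nu, singleton_nu, singleton ]
// while a rank-1 input yields just [ compressed ].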

RankedTensorType sparse_tensor::getCOOFromType(RankedTensorType src,
                                               bool ordered) {
  return getCOOFromTypeWithOrdering(
      src, AffineMap::getMultiDimIdentityMap(src.getRank(), src.getContext()),
      ordered);
}

// TODO: Remove this definition once all use-sites have been fixed to
// properly handle non-permutations.
Dimension mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc,
                                         Level l) {
  if (enc) {
    if (const auto dimToLvl = enc.getDimToLvl()) {
      assert(enc.isPermutation());
      return dimToLvl.getDimPosition(l);
    }
  }
  return l;
}

// TODO: Remove this definition once all use-sites have been fixed to
// properly handle non-permutations.
Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc,
                                       Dimension d) {
  if (enc) {
    if (const auto dimToLvl = enc.getDimToLvl()) {
      assert(enc.isPermutation());
      auto maybePos =
          dimToLvl.getResultPosition(getAffineDimExpr(d, enc.getContext()));
      assert(maybePos.has_value());
      return *maybePos;
    }
  }
  return d;
}

// TODO: Remove this definition once all use-sites have been fixed to
// properly handle non-permutations.
Dimension mlir::sparse_tensor::toOrigDim(RankedTensorType type, Level l) {
  const auto enc = getSparseTensorEncoding(type);
  assert(l < enc.getLvlRank());
  return toOrigDim(enc, l);
}

// TODO: Remove this definition once all use-sites have been fixed to
// properly handle non-permutations.
Level mlir::sparse_tensor::toStoredDim(RankedTensorType type, Dimension d) {
  assert(d < static_cast<Dimension>(type.getRank()));
  return toStoredDim(getSparseTensorEncoding(type), d);
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Types.
//===----------------------------------------------------------------------===//

825 
826 /// We normalized sparse tensor encoding attribute by always using
827 /// ordered/unique DLT such that "compressed_nu_no" and "compressed_nu" (as well
828 /// as other variants) lead to the same storage specifier type, and stripping
829 /// irrelevant fields that do not alter the sparse tensor memory layout.
830 static SparseTensorEncodingAttr
831 getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
833  for (auto dlt : enc.getLvlTypes())
834  dlts.push_back(*buildLevelType(*getLevelFormat(dlt), true, true));
835 
837  enc.getContext(), dlts,
838  AffineMap(), // dimToLvl (irrelevant to storage specifier)
839  // Always use `index` for memSize and lvlSize instead of reusing
840  // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
841  // value for different bitwidth, it also avoids casting between index and
842  // integer (returned by DimOp)
843  0, 0, enc.getDimSlices());
844 }
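
// For example, both of the following encodings normalize to the same
// specifier type:
//   { lvlTypes = [ "compressed_nu_no" ], crdWidth = 32 }
//   { lvlTypes = [ "compressed_nu" ] }
// since only the level-format survives while the ordered/unique properties
// and the bitwidths are stripped.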

StorageSpecifierType
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
  return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
//===----------------------------------------------------------------------===//

static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
  return success(lvl < getSparseTensorType(tensor).getLvlRank());
}

static LogicalResult isMatchingWidth(Value mem, unsigned width) {
  const Type etp = getMemRefType(mem).getElementType();
  return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
}

static LogicalResult verifySparsifierGetterSetter(
    StorageSpecifierKind mdKind, std::optional<Level> lvl,
    TypedValue<StorageSpecifierType> md, Operation *op) {
  if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
    return op->emitError(
        "redundant level argument for querying value memory size");
  }

  const auto enc = md.getType().getEncoding();
  const Level lvlRank = enc.getLvlRank();

  if (mdKind == StorageSpecifierKind::DimOffset ||
      mdKind == StorageSpecifierKind::DimStride)
    if (!enc.isSlice())
      return op->emitError("requested slice data on non-slice tensor");

  if (mdKind != StorageSpecifierKind::ValMemSize) {
    if (!lvl)
      return op->emitError("missing level argument");

    const Level l = lvl.value();
    if (l >= lvlRank)
      return op->emitError("requested level is out of bounds");

    if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
      return op->emitError(
          "requested position memory size on a singleton level");
  }
  return success();
}

static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
  switch (kind) {
  case SparseTensorFieldKind::CrdMemRef:
    return stt.getCrdType();
  case SparseTensorFieldKind::PosMemRef:
    return stt.getPosType();
  case SparseTensorFieldKind::ValMemRef:
    return stt.getElementType();
  case SparseTensorFieldKind::StorageSpec:
    return nullptr;
  }
  llvm_unreachable("Unrecognizable FieldKind");
}

static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                      SparseTensorType stt,
                                      RankedTensorType valTp,
                                      TypeRange lvlTps) {
  if (requiresStaticShape && !stt.hasStaticDimShape())
    return op->emitError("the sparse-tensor must have static shape");
  if (!stt.hasEncoding())
    return op->emitError("the sparse-tensor must have an encoding attribute");
  if (!stt.isIdentity())
    return op->emitError("the sparse-tensor must have the identity mapping");

  // Verify the trailing COO.
  Level cooStartLvl = getCOOStart(stt.getEncoding());
  if (cooStartLvl < stt.getLvlRank()) {
    // We only support trailing COO for now; it must be the last input.
    auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
    // The coordinates should be in shape of <? x rank>.
    unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
    if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
      return op->emitError("input/output trailing COO level-ranks don't match");
    }
  }

  // Verify that all types match.
  StorageLayout layout(stt.getEncoding());
  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
    return op->emitError("inconsistent number of fields between input/output");

  unsigned idx = 0;
  bool misMatch = false;
  layout.foreachField([&idx, &misMatch, stt, valTp,
                       lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
                               Level lvl, DimLevelType dlt) -> bool {
    if (fKind == SparseTensorFieldKind::StorageSpec)
      return true;

    Type inputTp = nullptr;
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      inputTp = valTp;
    } else {
      assert(fid == idx && stt.getLvlType(lvl) == dlt);
      inputTp = lvlTps[idx++];
    }
    // The input element type and expected element type should match.
    Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
    Type expElemTp = getFieldElemType(stt, fKind);
    if (inpElemTp != expElemTp) {
      misMatch = true;
      return false; // to terminate the iteration
    }
    return true;
  });

  if (misMatch)
    return op->emitError("input/output element-types don't match");
  return success();
}

LogicalResult PackOp::verify() {
  const auto valuesTp = getRankedTensorType(getValues());
  const auto lvlsTp = getLevels().getTypes();
  const auto resTp = getSparseTensorType(getResult());
  return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
}

LogicalResult UnpackOp::verify() {
  if (getOutValues().getType() != getRetValues().getType())
    return emitError("output values and return value type mismatch");

  for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
    if (ot.getType() != rt.getType())
      return emitError("output levels and return levels type mismatch");

  const auto valuesTp = getRankedTensorType(getRetValues());
  const auto lvlsTp = getRetLevels().getTypes();
  const auto srcTp = getSparseTensorType(getTensor());
  return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
}

LogicalResult ConvertOp::verify() {
  if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
    if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
      if (tp1.getRank() != tp2.getRank())
        return emitError("unexpected conversion mismatch in rank");
      auto dstEnc =
          llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
      if (dstEnc && dstEnc.isSlice())
        return emitError("cannot convert to a sparse tensor slice");

      auto shape1 = tp1.getShape();
      auto shape2 = tp2.getShape();
      // Accept size matches between the source and the destination type
      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches
      // or matches that would need a runtime assert (e.g. 10 vs. 20 or
      // ? vs. 10).
      for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
          return emitError("unexpected conversion mismatch in dimension ") << d;
      return success();
    }
  }
  return emitError("unexpected type in convert");
}

OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
  Type dstType = getType();
  // Fold trivial dense-to-dense convert and leave trivial sparse-to-sparse
  // convert for codegen to remove. This is because we use trivial
  // sparse-to-sparse convert to tell bufferization that the sparse codegen
  // will expand the tensor buffer into sparse tensor storage.
  if (!getSparseTensorEncoding(dstType) && dstType == getSource().getType())
    return getSource();
  return {};
}
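
// For illustration, a trivial dense-to-dense convert such as
//   %1 = sparse_tensor.convert %0 : tensor<8xf32> to tensor<8xf32>
// folds to %0 here, whereas its trivial sparse-to-sparse counterpart is
// deliberately left for codegen to eliminate.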

LogicalResult ToPositionsOp::verify() {
  auto e = getSparseTensorEncoding(getTensor().getType());
  if (failed(lvlIsInBounds(getLevel(), getTensor())))
    return emitError("requested level is out of bounds");
  if (failed(isMatchingWidth(getResult(), e.getPosWidth())))
    return emitError("unexpected type for positions");
  return success();
}

LogicalResult ToCoordinatesOp::verify() {
  auto e = getSparseTensorEncoding(getTensor().getType());
  if (failed(lvlIsInBounds(getLevel(), getTensor())))
    return emitError("requested level is out of bounds");
  if (failed(isMatchingWidth(getResult(), e.getCrdWidth())))
    return emitError("unexpected type for coordinates");
  return success();
}

LogicalResult ToCoordinatesBufferOp::verify() {
  auto e = getSparseTensorEncoding(getTensor().getType());
  if (getCOOStart(e) >= e.getLvlRank())
    return emitError("expected sparse tensor with a COO region");
  return success();
}

LogicalResult ToValuesOp::verify() {
  auto ttp = getRankedTensorType(getTensor());
  auto mtp = getMemRefType(getResult());
  if (ttp.getElementType() != mtp.getElementType())
    return emitError("unexpected mismatch in element types");
  return success();
}

LogicalResult ToSliceOffsetOp::verify() {
  auto rank = getRankedTensorType(getSlice()).getRank();
  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
    return emitError("requested dimension out of bound");
  return success();
}

LogicalResult ToSliceStrideOp::verify() {
  auto rank = getRankedTensorType(getSlice()).getRank();
  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
    return emitError("requested dimension out of bound");
  return success();
}

LogicalResult GetStorageSpecifierOp::verify() {
  RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
      getSpecifierKind(), getLevel(), getSpecifier(), getOperation()))
  return success();
}

template <typename SpecifierOp>
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
  return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
}

OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
  const StorageSpecifierKind kind = getSpecifierKind();
  const auto lvl = getLevel();
  for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
    if (kind == op.getSpecifierKind() && lvl == op.getLevel())
      return op.getValue();
  return {};
}
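
// For illustration, the fold above forwards the value through a matching
// set/get pair (op syntax approximate; see SparseTensorOps.td):
//   %1 = sparse_tensor.storage_specifier.set %0 lvl_sz at 0 with %v
//   %2 = sparse_tensor.storage_specifier.get %1 lvl_sz at 0
// so %2 folds to %v.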

LogicalResult SetStorageSpecifierOp::verify() {
  RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
      getSpecifierKind(), getLevel(), getSpecifier(), getOperation()))
  return success();
}

//===----------------------------------------------------------------------===//
// TensorDialect Linalg.Generic Operations.
//===----------------------------------------------------------------------===//

template <class T>
static LogicalResult verifyNumBlockArgs(T *op, Region &region,
                                        const char *regionName,
                                        TypeRange inputTypes, Type outputType) {
  unsigned numArgs = region.getNumArguments();
  unsigned expectedNum = inputTypes.size();
  if (numArgs != expectedNum)
    return op->emitError() << regionName << " region must have exactly "
                           << expectedNum << " arguments";

  for (unsigned i = 0; i < numArgs; i++) {
    Type typ = region.getArgument(i).getType();
    if (typ != inputTypes[i])
      return op->emitError() << regionName << " region argument " << (i + 1)
                             << " type mismatch";
  }
  Operation *term = region.front().getTerminator();
  YieldOp yield = dyn_cast<YieldOp>(term);
  if (!yield)
    return op->emitError() << regionName
                           << " region must end with sparse_tensor.yield";
  if (!yield.getResult() || yield.getResult().getType() != outputType)
    return op->emitError() << regionName << " region yield type mismatch";

  return success();
}

LogicalResult BinaryOp::verify() {
  NamedAttrList attrs = (*this)->getAttrs();
  Type leftType = getX().getType();
  Type rightType = getY().getType();
  Type outputType = getOutput().getType();
  Region &overlap = getOverlapRegion();
  Region &left = getLeftRegion();
  Region &right = getRightRegion();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  if (!overlap.empty()) {
    RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
        this, overlap, "overlap", TypeRange{leftType, rightType}, outputType))
  }
  if (!left.empty()) {
    RETURN_FAILURE_IF_FAILED(
        verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType))
  } else if (getLeftIdentity()) {
    if (leftType != outputType)
      return emitError("left=identity requires first argument to have the same "
                       "type as the output");
  }
  if (!right.empty()) {
    RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
        this, right, "right", TypeRange{rightType}, outputType))
  } else if (getRightIdentity()) {
    if (rightType != outputType)
      return emitError("right=identity requires second argument to have the "
                       "same type as the output");
  }
  return success();
}
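
// For illustration, a well-formed binary op satisfying the checks above
// (syntax approximate; see SparseTensorOps.td for the authoritative form):
//
//   %0 = sparse_tensor.binary %x, %y : f64, f64 to f64
//     overlap = {
//       ^bb0(%a: f64, %b: f64):
//         %sum = arith.addf %a, %b : f64
//         sparse_tensor.yield %sum : f64
//     }
//     left = identity
//     right = {}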

LogicalResult UnaryOp::verify() {
  Type inputType = getX().getType();
  Type outputType = getOutput().getType();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  Region &present = getPresentRegion();
  if (!present.empty()) {
    RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
        this, present, "present", TypeRange{inputType}, outputType))
  }
  Region &absent = getAbsentRegion();
  if (!absent.empty()) {
    RETURN_FAILURE_IF_FAILED(
        verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType))
  }
  return success();
}

LogicalResult ConcatenateOp::verify() {
  const auto dstTp = getSparseTensorType(*this);
  const Dimension concatDim = getDimension();
  const Dimension dimRank = dstTp.getDimRank();

  if (getInputs().size() <= 1)
    return emitError("Need at least two tensors to concatenate.");

  if (concatDim >= dimRank)
    return emitError(llvm::formatv(
        "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
        concatDim, dimRank));

  for (const auto &it : llvm::enumerate(getInputs())) {
    const auto i = it.index();
    const auto srcTp = getSparseTensorType(it.value());
    if (srcTp.hasDynamicDimShape())
      return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
    const Dimension srcDimRank = srcTp.getDimRank();
    if (srcDimRank != dimRank)
      return emitError(
          llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
                        "from the output tensor (rank={2}).",
                        i, srcDimRank, dimRank));
  }

  for (Dimension d = 0; d < dimRank; d++) {
    const DynSize dstSh = dstTp.getDimShape()[d];
    if (d == concatDim) {
      if (!ShapedType::isDynamic(dstSh)) {
        // If we reach here, then all inputs have static shapes. So we
        // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
        // to avoid redundant assertions in the loop.
        StaticSize sumSz = 0;
        for (const auto src : getInputs())
          sumSz += getSparseTensorType(src).getDimShape()[d];
        // If all dimensions are statically known, the sum of all the input
        // dimensions should be equal to the output dimension.
        if (sumSz != dstSh)
          return emitError(
              "The concatenation dimension of the output tensor should be the "
              "sum of all the concatenation dimensions of the input tensors.");
      }
    } else {
      DynSize prev = dstSh;
      for (const auto src : getInputs()) {
        const auto sh = getSparseTensorType(src).getDimShape()[d];
        if (!ShapedType::isDynamic(prev) && sh != prev)
          return emitError("All dimensions (except for the concatenating one) "
                           "should be equal.");
        prev = sh;
      }
    }
  }

  return success();
}

LogicalResult InsertOp::verify() {
  const auto stt = getSparseTensorType(getTensor());
  if (stt.getLvlRank() != static_cast<Level>(getLvlCoords().size()))
    return emitOpError("incorrect number of coordinates");
  return success();
}

void PushBackOp::build(OpBuilder &builder, OperationState &result,
                       Value curSize, Value inBuffer, Value value) {
  build(builder, result, curSize, inBuffer, value, Value());
}

LogicalResult PushBackOp::verify() {
  if (Value n = getN()) {
    std::optional<int64_t> nValue = getConstantIntValue(n);
    if (nValue && nValue.value() < 1)
      return emitOpError("n must be not less than 1");
  }
  return success();
}

LogicalResult CompressOp::verify() {
  const auto stt = getSparseTensorType(getTensor());
  if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
    return emitOpError("incorrect number of coordinates");
  return success();
}

void ForeachOp::build(
    OpBuilder &builder, OperationState &result, Value tensor,
    ValueRange initArgs, AffineMapAttr order,
    function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
        bodyBuilder) {
  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
  // Builds the foreach body.
  if (!bodyBuilder)
    return;
  const auto stt = getSparseTensorType(tensor);
  const Dimension dimRank = stt.getDimRank();

  // Starts with `dimRank`-many coordinates.
  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
  // Followed by one value.
  blockArgTypes.push_back(stt.getElementType());
  // Followed by the reduction variables.
  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());

  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());

  OpBuilder::InsertionGuard guard(builder);
  auto &region = *result.regions.front();
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuilder(builder, result.location,
              bodyBlock->getArguments().slice(0, dimRank),
              bodyBlock->getArguments()[dimRank],
              bodyBlock->getArguments().drop_front(dimRank + 1));
}
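
// For illustration, for a 2-D tensor with one init argument the block built
// above has the signature
//   ^bb0(%coord0: index, %coord1: index, %value: eltType, %acc: initType)
// i.e. `dimRank` coordinates, then the element value, then the reduction
// variables (%coord*, %value, and %acc are hypothetical names).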

LogicalResult ForeachOp::verify() {
  const auto t = getSparseTensorType(getTensor());
  const Dimension dimRank = t.getDimRank();
  const auto args = getBody()->getArguments();

  if (getOrder().has_value() &&
      (t.getEncoding() || !getOrder()->isPermutation()))
    return emitError("Only support permuted order on non encoded dense tensor");

  if (static_cast<size_t>(dimRank) + 1 + getInitArgs().size() != args.size())
    return emitError("Unmatched number of arguments in the block");

  if (getNumResults() != getInitArgs().size())
    return emitError("Mismatch in number of init arguments and results");

  if (getResultTypes() != getInitArgs().getTypes())
    return emitError("Mismatch in types of init arguments and results");

  // Cannot mark this const, because the getters aren't.
  auto yield = cast<YieldOp>(getBody()->getTerminator());
  if (yield.getNumOperands() != getNumResults() ||
      yield.getOperands().getTypes() != getResultTypes())
    return emitError("Mismatch in types of yield values and results");

  const auto iTp = IndexType::get(getContext());
  for (Dimension d = 0; d < dimRank; d++)
    if (args[d].getType() != iTp)
      return emitError(
          llvm::formatv("Expecting Index type for argument at index {0}", d));

  const auto elemTp = t.getElementType();
  const auto valueTp = args[dimRank].getType();
  if (elemTp != valueTp)
    return emitError(
        llvm::formatv("Unmatched element type between input tensor and "
                      "block argument, expected:{0}, got: {1}",
                      elemTp, valueTp));
  return success();
}

LogicalResult ReduceOp::verify() {
  Type inputType = getX().getType();
  // Check correct number of block arguments and return type.
  Region &formula = getRegion();
  RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
      this, formula, "reduce", TypeRange{inputType, inputType}, inputType))
  return success();
}

LogicalResult SelectOp::verify() {
  Builder b(getContext());
  Type inputType = getX().getType();
  Type boolType = b.getI1Type();
  // Check correct number of block arguments and return type.
  Region &formula = getRegion();
  RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(this, formula, "select",
                                              TypeRange{inputType}, boolType))
  return success();
}

LogicalResult SortCooOp::verify() {
  AffineMap xPerm = getPermMap();
  uint64_t nx = xPerm.getNumDims();
  if (nx < 1)
    return emitError(
        llvm::formatv("Expected rank(perm_map) >= 1, got {0}", nx));

  if (!xPerm.isPermutation())
    return emitError(
        llvm::formatv("Expected a permutation map, got {0}", xPerm));

  std::optional<int64_t> cn = getConstantIntValue(getN());
  // We can't check the size of the buffers when n or the buffer dimensions
  // aren't compile-time constants.
  if (!cn)
    return success();

  uint64_t n = cn.value();
  uint64_t ny = 0;
  if (auto nyAttr = getNyAttr()) {
    ny = nyAttr.getInt();
  }

  // FIXME: update the types of variables used in expressions passed as
  // the `minSize` argument, to avoid implicit casting at the callsites
  // of this lambda.
  const auto checkDim = [&](Value v, StaticSize minSize, const char *message) {
    const DynSize sh = getMemRefType(v).getShape()[0];
    if (!ShapedType::isDynamic(sh) && sh < minSize)
      emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
  };

  checkDim(getXy(), n * (nx + ny),
           "Expected dimension(xy) >= n * (rank(perm_map) + ny)");

  for (Value opnd : getYs()) {
    checkDim(opnd, n, "Expected dimension(y) >= n");
  }

  return success();
}

LogicalResult YieldOp::verify() {
  // Check for compatible parent.
  auto *parentOp = (*this)->getParentOp();
  if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
      isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
      isa<ForeachOp>(parentOp))
    return success();

  return emitOpError("expected parent op to be sparse_tensor unary, binary, "
                     "reduce, select or foreach");
}

#undef RETURN_FAILURE_IF_FAILED

//===----------------------------------------------------------------------===//
// TensorDialect Methods.
//===----------------------------------------------------------------------===//

void SparseTensorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"

#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
static bool isPermutation(std::vector< PermutationTy > permutation)
Definition: IRAffine.cpp:58
static MLIRContext * getContext(OpFoldResult val)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static LogicalResult verifyNumBlockArgs(T *op, Region &region, const char *regionName, TypeRange inputTypes, Type outputType)
static ParseResult parseOptionalStaticSlice(int64_t &result, AsmParser &parser)
static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc)
We normalized sparse tensor encoding attribute by always using ordered/unique DLT such that "compress...
static const DimLevelType validDLTs[]
static LogicalResult isMatchingWidth(Value mem, unsigned width)
static constexpr bool acceptBitWidth(unsigned bitWidth)
static LogicalResult lvlIsInBounds(Level lvl, Value tensor)
static constexpr FieldIndex kDataFieldStartingIdx
static constexpr Level kInvalidLevel
static constexpr Level kInvalidFieldIndex
static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind)
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op)
#define RETURN_ON_FAIL(stmt)
#define RETURN_ON_FALSE(fidx, kind, lvl, dlt)
static std::optional< DimLevelType > parseDLT(StringRef str)
#define RETURN_FAILURE_IF_FAILED(X)
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, SparseTensorType stt, RankedTensorType valTp, TypeRange lvlTps)
static LogicalResult verifySparsifierGetterSetter(StorageSpecifierKind mdKind, std::optional< Level > lvl, TypedValue< StorageSpecifierType > md, Operation *op)
#define ERROR_IF(COND, MSG)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:44
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:358
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:275
std::optional< unsigned > getResultPosition(AffineExpr input) const
Extracts the first result position where input dimension resides.
Definition: AffineMap.cpp:362
unsigned getNumDims() const
Definition: AffineMap.cpp:337
unsigned getNumResults() const
Definition: AffineMap.cpp:345
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:564
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:68
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseQuestion()=0
Parse a '?' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgListType getArguments()
Definition: Block.h:80
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IndexType getIndexType()
Definition: Builders.cpp:71
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:419
This class represents a single result from folding an operation.
Definition: OpDefinition.h:266
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
Definition: Types.cpp:59
bool isIndex() const
Definition: Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:93
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
A wrapper around RankedTensorType, which has three goals: to provide a uniform API for querying sparse-tensor aspects (making the dimension vs. level distinction explicit), to handle sparse- and dense-tensor types uniformly, and to provide uniform handling of defaults.
DimLevelType getLvlType(Level l) const
unsigned getCrdWidth() const
Returns the coordinate-overhead bitwidth, defaulting to zero.
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
Dimension getDimRank() const
Returns the dimension-rank.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
ArrayRef< DynSize > getDimShape() const
Returns the dimension-shape.
bool hasStaticDimShape() const
Returns true if no dimension has dynamic size.
Level getLvlRank() const
Returns the level-rank.
unsigned getPosWidth() const
Returns the position-overhead bitwidth, defaulting to zero.
bool isPermutation() const
Returns true if the dimToLvl mapping is a permutation.
SparseTensorEncodingAttr getEncoding() const
Returns the encoding (or the null-attribute for dense-tensors).
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
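A minimal sketch of how these accessors combine when inspecting a tensor value; the function name is illustrative only:

// Sketch: print the per-level layout of a (possibly sparse) tensor value.
static void dumpLayout(Value tensor) {
  const SparseTensorType stt = getSparseTensorType(tensor);
  if (!stt.hasEncoding())
    return; // Dense tensor: nothing sparse to report.
  for (Level l = 0, e = stt.getLvlRank(); l < e; l++)
    llvm::errs() << "level " << l << ": " << toMLIRString(stt.getLvlType(l))
                 << "\n";
  // Overhead types default to IndexType when the widths are zero.
  Type posType = stt.getPosType();
  Type crdType = stt.getCrdType();
  (void)posType;
  (void)crdType;
}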
Provides methods to access fields of a sparse tensor with the given encoding.
unsigned getNumDataFields() const
Gets the total number of data fields (coordinate arrays, position arrays, and a value array) for the given sparse tensor encoding.
unsigned getNumFields() const
Gets the total number of fields for the given sparse tensor encoding.
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, DimLevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with the corresponding field index, field kind, level, and dimension level type.
std::pair< FieldIndex, unsigned > getFieldIndexAndStride(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Parses the Sparse Tensor Encoding Attribute (STEA).
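A minimal sketch of walking the storage fields, assuming StorageLayout can be constructed directly from the encoding attribute; returning false from the callback stops the traversal early:

// Sketch: count the coordinate arrays in the storage layout of `enc`.
static unsigned countCrdArrays(SparseTensorEncodingAttr enc) {
  unsigned numCrd = 0;
  StorageLayout(enc).foreachField(
      [&numCrd](FieldIndex, SparseTensorFieldKind fieldKind, Level,
                DimLevelType) -> bool {
        if (fieldKind == SparseTensorFieldKind::CrdMemRef)
          numCrd++;
        return true; // Visit every field.
      });
  return numCrd;
}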
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Type getIntegerOrIndexType(MLIRContext *ctx, unsigned bitwidth)
constexpr std::optional< DimLevelType > buildLevelType(LevelFormat lf, bool ordered, bool unique)
Convert a LevelFormat to its corresponding DimLevelType with the given properties.
Definition: Enums.h:320
constexpr std::optional< LevelFormat > getLevelFormat(DimLevelType dlt)
Convert a DimLevelType to its corresponding LevelFormat.
Definition: Enums.h:308
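The two conversions above round-trip where defined, as in this minimal sketch:

// Sketch: build a compressed, ordered, unique level type and query it back.
std::optional<DimLevelType> dlt =
    buildLevelType(LevelFormat::Compressed, /*ordered=*/true, /*unique=*/true);
assert(dlt && getLevelFormat(*dlt) == LevelFormat::Compressed);
assert(isOrderedDLT(*dlt));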
Level getCOOStart(SparseTensorEncodingAttr enc)
Returns the starting level for a trailing COO region that spans at least two levels.
DEPRECATED Level toStoredDim(SparseTensorEncodingAttr enc, Dimension d)
[deprecated] Convenience method to translate the given dimension to the corresponding level.
RankedTensorType getCOOFromType(RankedTensorType src, bool ordered)
unsigned FieldIndex
The type of field indices.
uint64_t Dimension
The type of dimension identifiers, and dimension-ranks.
Definition: SparseTensor.h:40
SparseTensorType getSparseTensorType(T t)
Convenience method to abbreviate wrapping getRankedTensorType.
uint64_t Level
The type of level identifiers, and level-ranks.
Definition: SparseTensor.h:46
constexpr bool isDenseDLT(DimLevelType dlt)
Check if the DimLevelType is dense.
Definition: Enums.h:265
RankedTensorType getRankedTensorType(T &&t)
Convenience method to abbreviate casting getType().
Definition: SparseTensor.h:96
int64_t DynSize
The type for individual components of a compile-time shape.
Definition: SparseTensor.h:52
bool isUniqueCOOType(Type tp)
Returns true iff the given type is a COO type where the last level is unique.
bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, bool isUnique)
Returns true iff the given sparse tensor encoding attribute has a trailing COO region starting at the given level.
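A minimal sketch combining these COO queries; by the convention used in this file, getCOOStart returning the level-rank means there is no trailing COO region:

// Sketch: detect a trailing, unique COO region on a ranked tensor type.
static bool hasTrailingCOO(RankedTensorType tp) {
  const auto enc = getSparseTensorEncoding(tp);
  if (!enc)
    return false;
  const Level cooStart = getCOOStart(enc);
  return cooStart < enc.getLvlRank() &&
         isCOOType(enc, cooStart, /*isUnique=*/true);
}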
DEPRECATED Dimension toOrigDim(SparseTensorEncodingAttr enc, Level l)
[deprecated] Convenience method to translate the given level to the corresponding dimension.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
RankedTensorType getCOOFromTypeWithOrdering(RankedTensorType src, AffineMap ordering, bool ordered)
Helpers to setup a COO type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Definition: SparseTensor.h:104
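A minimal sketch of these casting conveniences, assuming `v` is a Value already known to carry a ranked tensor type:

// Sketch: the helpers abbreviate the explicit cast<...>(v.getType()) dance.
static void inspectValue(Value v) {
  RankedTensorType rtp = getRankedTensorType(v);
  SparseTensorType stt = getSparseTensorType(v); // wraps the same cast
  if (stt.hasEncoding() && stt.isIdentity()) {
    // Identity dimToLvl: dimension d is stored as level d.
  }
  (void)rtp;
}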
DimLevelType
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:175
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, DimLevelType)>)
int64_t StaticSize
The type for individual components of a compile-time shape which are known not to be ShapedType::kDynamic.
Definition: SparseTensor.h:56
SparseTensorFieldKind
The kinds of fields in the sparse tensor storage scheme: per-level position and coordinate arrays, the values array, and the trailing storage specifier.
constexpr const char * toMLIRString(DimLevelType dlt)
Returns the string representation of the given dimension level type.
Definition: Enums.h:212
constexpr bool isDLTWithPos(DimLevelType dlt)
Convenience method to query whether a given DLT needs both position and coordinates array or only coordinates array.
Definition: SparseTensor.h:116
constexpr bool isDLTWithCrd(DimLevelType dlt)
Definition: SparseTensor.h:119
constexpr bool isOrderedDLT(DimLevelType dlt)
Check if the DimLevelType is ordered (regardless of storage format).
Definition: Enums.h:297
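A minimal sketch using these predicates to report which overhead arrays a level contributes, mirroring the per-level test in foreachField:

// Sketch: describe the storage needs of a single level type.
static void describe(DimLevelType dlt) {
  llvm::errs() << toMLIRString(dlt) << ": ";
  if (isDenseDLT(dlt))
    llvm::errs() << "no overhead arrays";
  else
    llvm::errs() << (isDLTWithPos(dlt) ? "positions + " : "")
                 << (isDLTWithCrd(dlt) ? "coordinates" : "");
  llvm::errs() << (isOrderedDLT(dlt) ? " (ordered)\n" : " (unordered)\n");
}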
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:489
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
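A minimal sketch of the verifier-style idiom these utilities support; verifyWidth is a hypothetical helper:

// Sketch: emit a diagnostic and signal failure, or signal success.
static LogicalResult verifyWidth(Location loc, Type t, unsigned width) {
  if (!t.isInteger(width) && !t.isIndex())
    return emitError(loc) << "expected integer of width " << width
                          << " or index type";
  return success();
}
// Caller side:
//   if (failed(verifyWidth(loc, posType, 64)))
//     return failure();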
SetVector< Operation * > getSlice(Operation *op, BackwardSliceOptions backwardSliceOptions={}, ForwardSliceOptions forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using the classes in the detail namespace directly.
Definition: AffineExpr.cpp:502
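In this dialect, getAffineDimExpr is the building block for dimToLvl maps. A minimal sketch constructing the 2-D transposition map (d0, d1) -> (d1, d0):

// Sketch: build the permutation map (d0, d1) -> (d1, d0).
static AffineMap buildTranspose(MLIRContext *ctx) {
  SmallVector<AffineExpr> results{getAffineDimExpr(1, ctx),
                                  getAffineDimExpr(0, ctx)};
  return AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, results, ctx);
}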
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on debug builds.
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.