26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/FormatVariadic.h"
29 #define GET_ATTRDEF_CLASSES
30 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
31 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
41 #define GET_TYPEDEF_CLASSES
42 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
79 if (dimShape.has_value()) {
83 enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
84 memrefShape.assign(lvlShape.begin(),
85 lvlShape.begin() + enc.getBatchLvlRank());
88 memrefShape.push_back(ShapedType::kDynamic);
104 const auto lvlTypes = enc.getLvlTypes();
105 const Level lvlRank = enc.getLvlRank();
111 for (
Level l = 0; l < lvlRank; ) {
112 const auto lt = lvlTypes[l];
121 if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
122 if (!cooSegsRef.front().isSoA) {
125 l = cooSegsRef.front().lvlRange.second;
131 cooSegsRef = cooSegsRef.drop_front();
171 return callback(specType, fieldIdx, fieldKind, lvl, lt);
173 return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
175 return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
177 return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
179 llvm_unreachable(
"unrecognized field kind");
184 unsigned numFields = 0;
194 unsigned numFields = 0;
206 std::pair<FieldIndex, unsigned>
208 std::optional<Level> lvl)
const {
212 assert(lvl.has_value());
213 const Level cooStart = enc.getAoSCOOStart();
214 const Level lvlRank = enc.getLvlRank();
215 if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
217 stride = lvlRank - cooStart;
223 if ((lvl && fLvl == lvl.value() &&
kind == fKind) ||
232 return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
239 std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
240 return isDynamic(v) ? std::nullopt
241 : std::make_optional(
static_cast<uint64_t
>(v));
244 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset()
const {
245 return getStatic(getOffset());
248 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride()
const {
252 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize()
const {
253 return getStatic(getSize());
256 bool SparseTensorDimSliceAttr::isCompletelyDynamic()
const {
257 return isDynamic(getOffset()) && isDynamic(
getStride()) &&
258 isDynamic(getSize());
261 std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
262 return isDynamic(v) ?
"?" : std::to_string(v);
266 assert(getImpl() &&
"Uninitialized SparseTensorDimSliceAttr");
268 os << getStaticString(getOffset());
270 os << getStaticString(getSize());
283 if (parseResult.has_value()) {
284 if (parseResult.value().succeeded() && result < 0) {
287 "expect positive value or ? for slice offset/size/stride");
290 return parseResult.value();
294 result = SparseTensorDimSliceAttr::kDynamic;
299 int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
311 offset, size, stride);
316 int64_t offset, int64_t size, int64_t stride) {
317 if (!isDynamic(offset) && offset < 0)
318 return emitError() <<
"expect non-negative value or ? for slice offset";
319 if (!isDynamic(size) && size <= 0)
320 return emitError() <<
"expect positive value or ? for slice size";
321 if (!isDynamic(stride) && stride <= 0)
322 return emitError() <<
"expect positive value or ? for slice stride";
326 SparseTensorEncodingAttr
327 SparseTensorEncodingAttr::withDimToLvl(
AffineMap dimToLvl)
const {
328 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
331 getCrdWidth(), getExplicitVal(), getImplicitVal());
334 SparseTensorEncodingAttr
335 SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc)
const {
336 return withDimToLvl(enc ? enc.getDimToLvl() :
AffineMap());
339 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl()
const {
343 SparseTensorEncodingAttr
344 SparseTensorEncodingAttr::withBitWidths(
unsigned posWidth,
345 unsigned crdWidth)
const {
346 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
348 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
349 crdWidth, getExplicitVal(), getImplicitVal());
352 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths()
const {
353 return withBitWidths(0, 0);
356 SparseTensorEncodingAttr
357 SparseTensorEncodingAttr::withExplicitVal(
Attribute explicitVal)
const {
358 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
360 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
361 getCrdWidth(), explicitVal, getImplicitVal());
364 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal()
const {
368 SparseTensorEncodingAttr
369 SparseTensorEncodingAttr::withImplicitVal(
Attribute implicitVal)
const {
370 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
372 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
373 getCrdWidth(), getExplicitVal(), implicitVal);
376 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal()
const {
380 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
383 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
384 getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
387 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices()
const {
391 uint64_t SparseTensorEncodingAttr::getBatchLvlRank()
const {
393 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(),
isBatchLT);
394 return std::distance(lastBatch, lvlTypes.rend());
398 return !getImpl() || llvm::all_of(getLvlTypes(),
isDenseLT);
401 bool SparseTensorEncodingAttr::isAllOrdered()
const {
402 return !getImpl() || llvm::all_of(getLvlTypes(),
isOrderedLT);
405 Type SparseTensorEncodingAttr::getCrdElemType()
const {
413 Type SparseTensorEncodingAttr::getPosElemType()
const {
421 MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
427 MemRefType SparseTensorEncodingAttr::getPosMemRefType(
433 bool SparseTensorEncodingAttr::isIdentity()
const {
434 return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
438 return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
441 Dimension SparseTensorEncodingAttr::getDimRank()
const {
442 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
443 const auto dimToLvl = getDimToLvl();
444 return dimToLvl ? dimToLvl.
getNumDims() : getLvlRank();
447 Level SparseTensorEncodingAttr::getLvlRank()
const {
448 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
449 return getLvlTypes().size();
455 assert(l < getLvlRank() &&
"Level is out of bounds");
456 return getLvlTypes()[l];
459 bool SparseTensorEncodingAttr::isSlice()
const {
460 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
461 return !getDimSlices().empty();
464 SparseTensorDimSliceAttr
465 SparseTensorEncodingAttr::getDimSlice(
Dimension dim)
const {
466 assert(isSlice() &&
"Is not a slice");
467 const auto dimSlices = getDimSlices();
468 assert(dim < dimSlices.size() &&
"Dimension is out of bounds");
469 return dimSlices[dim];
472 std::optional<uint64_t>
473 SparseTensorEncodingAttr::getStaticDimSliceOffset(
Dimension dim)
const {
474 return getDimSlice(dim).getStaticOffset();
477 std::optional<uint64_t>
478 SparseTensorEncodingAttr::getStaticDimSliceStride(
Dimension dim)
const {
479 return getDimSlice(dim).getStaticStride();
482 std::optional<uint64_t>
483 SparseTensorEncodingAttr::getStaticLvlSliceOffset(
Level lvl)
const {
484 return getStaticDimSliceOffset(
toDim(*
this, lvl));
487 std::optional<uint64_t>
488 SparseTensorEncodingAttr::getStaticLvlSliceStride(
Level lvl)
const {
489 return getStaticDimSliceStride(
toDim(*
this, lvl));
494 CrdTransDirectionKind dir)
const {
500 dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
504 for (
unsigned r = 0; r < rank; r++) {
505 unsigned trans = dir == CrdTransDirectionKind::dim2lvl ?
toDim(*
this, r)
507 ret.push_back(srcShape[trans]);
514 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
517 dimRep.reserve(srcShape.size());
518 for (int64_t sz : srcShape) {
519 if (ShapedType::isStatic(sz)) {
533 if (
auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
534 ret.push_back(c.getValue() + 1);
536 if (
auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
540 if (
auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
541 ret.push_back(bound.getValue());
545 ret.push_back(ShapedType::kDynamic);
548 assert(ret.size() == rank);
555 CrdTransDirectionKind dir)
const {
560 dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
563 CrdTranslateOp::create(builder, loc, retType, crds, dir, *
this);
564 return transOp.getOutCrds();
579 unsigned posWidth = 0;
580 unsigned crdWidth = 0;
585 "explicitVal",
"implicitVal"};
588 auto *it = find(keys, attrName);
589 if (it == keys.end()) {
593 unsigned keyWordIndex = it - keys.begin();
598 switch (keyWordIndex) {
601 auto res = cParser.parseDimLvlMap();
604 const auto &dlm = *res;
606 const Level lvlRank = dlm.getLvlRank();
607 for (
Level lvl = 0; lvl < lvlRank; lvl++)
608 lvlTypes.push_back(dlm.getLvlType(lvl));
610 const Dimension dimRank = dlm.getDimRank();
611 for (
Dimension dim = 0; dim < dimRank; dim++)
612 dimSlices.push_back(dlm.getDimSlice(dim));
616 const auto isDefined = [](SparseTensorDimSliceAttr slice) {
617 return static_cast<bool>(slice.getImpl());
619 if (llvm::any_of(dimSlices, isDefined)) {
620 const auto defaultSlice =
622 for (
Dimension dim = 0; dim < dimRank; dim++)
623 if (!isDefined(dimSlices[dim]))
624 dimSlices[dim] = defaultSlice;
629 dimToLvl = dlm.getDimToLvlMap(parser.
getContext());
630 lvlToDim = dlm.getLvlToDimMap(parser.
getContext());
637 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
640 "expected an integral position bitwidth");
643 posWidth = intAttr.getInt();
650 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
653 "expected an integral index bitwidth");
656 crdWidth = intAttr.getInt();
663 if (
auto result = llvm::dyn_cast<FloatAttr>(attr)) {
664 explicitVal = result;
665 }
else if (
auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
666 explicitVal = result;
667 }
else if (
auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
668 explicitVal = result;
671 "expected a numeric value for explicitVal");
680 if (
auto result = llvm::dyn_cast<FloatAttr>(attr)) {
681 implicitVal = result;
682 }
else if (
auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
683 implicitVal = result;
684 }
else if (
auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
685 implicitVal = result;
688 "expected a numeric value for implicitVal");
706 if (!lvlToDim || lvlToDim.
isEmpty()) {
709 return parser.
getChecked<SparseTensorEncodingAttr>(
710 parser.
getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
711 explicitVal, implicitVal, dimSlices);
715 auto map =
static_cast<AffineMap>(getDimToLvl());
719 printer <<
"<{ map = ";
720 printSymbols(map, printer);
722 printDimensions(map, printer, getDimSlices());
724 printLevels(map, printer, getLvlTypes());
728 printer <<
", posWidth = " << getPosWidth();
730 printer <<
", crdWidth = " << getCrdWidth();
731 if (getExplicitVal()) {
732 printer <<
", explicitVal = " << getExplicitVal();
734 if (getImplicitVal())
735 printer <<
", implicitVal = " << getImplicitVal();
739 void SparseTensorEncodingAttr::printSymbols(
AffineMap &map,
744 for (
unsigned i = 0, n = map.
getNumSymbols() - 1; i < n; i++)
745 printer <<
's' << i <<
", ";
751 void SparseTensorEncodingAttr::printDimensions(
754 if (!dimSlices.empty()) {
755 for (
unsigned i = 0, n = map.
getNumDims() - 1; i < n; i++)
756 printer <<
'd' << i <<
" : " << dimSlices[i] <<
", ";
758 printer <<
'd' << map.
getNumDims() - 1 <<
" : "
762 for (
unsigned i = 0, n = map.
getNumDims() - 1; i < n; i++)
763 printer <<
'd' << i <<
", ";
771 for (
unsigned i = 0, n = map.
getNumResults() - 1; i < n; i++) {
788 return emitError() <<
"unexpected position bitwidth: " << posWidth;
790 return emitError() <<
"unexpected coordinate bitwidth: " << crdWidth;
794 while (it != lvlTypes.end()) {
795 if (it == lvlTypes.begin() ||
797 return emitError() <<
"expected compressed or loose_compressed level "
798 "before singleton level";
800 auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(),
isSingletonLT);
802 return emitError() <<
"expected all singleton lvlTypes "
803 "following a singleton level";
805 if (!std::all_of(it, curCOOEnd, [it](
LevelType i) {
809 return emitError() <<
"expected all singleton lvlTypes stored in the "
810 "same memory layout (SoA vs AoS).";
815 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(),
isBatchLT);
816 if (!std::all_of(lastBatch, lvlTypes.rend(),
isBatchLT))
817 return emitError() <<
"Batch lvlType can only be leading levels.";
820 auto soaLvls = llvm::make_filter_range(lvlTypes, [](
LevelType lt) {
823 if (llvm::any_of(soaLvls, [](
LevelType lt) {
826 return emitError() <<
"SoA is only applicable to singleton lvlTypes.";
830 if (
auto it = llvm::find_if(lvlTypes,
isNOutOfMLT);
831 it != std::end(lvlTypes)) {
832 if (it != lvlTypes.end() - 1)
833 return emitError() <<
"expected n_out_of_m to be the last level type";
834 if (!std::all_of(lvlTypes.begin(), it,
isDenseLT))
835 return emitError() <<
"expected all dense lvlTypes "
836 "before a n_out_of_m level";
840 <<
"expected 1xm block structure for n_out_of_m level";
843 unsigned coefficient = 0;
844 for (
const auto &elem : sizes) {
846 if (elem != coefficient && coefficient != 0) {
847 return emitError() <<
"expected only one blocked level "
848 "with the same coefficients";
853 if (coefficient !=
getM(*it)) {
854 return emitError() <<
"expected coeffiencts of Affine expressions "
855 "to be equal to m of n_out_of_m level";
864 const Level lvlRank = lvlTypes.size();
866 return emitError() <<
"expected a non-empty array for lvlTypes";
872 <<
"level-rank mismatch between dimToLvl and lvlTypes: "
877 return emitError() <<
"failed to infer lvlToDim from dimToLvl";
878 if (lvlToDim && (inferRes != lvlToDim))
879 return emitError() <<
"expected lvlToDim to be an inverse of dimToLvl";
880 if (dimRank > lvlRank)
881 return emitError() <<
"unexpected dimToLvl mapping from " << dimRank
882 <<
" to " << lvlRank;
884 if (!dimSlices.empty()) {
885 if (dimSlices.size() != dimRank)
887 <<
"dimension-rank mismatch between dimSlices and dimToLvl: "
888 << dimSlices.size() <<
" != " << dimRank;
891 if (dimRank != lvlRank)
893 <<
"dimSlices expected dimension-rank to match level-rank: "
894 << dimRank <<
" != " << lvlRank;
899 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
905 getPosWidth(), getCrdWidth(), getExplicitVal(),
906 getImplicitVal(), getDimSlices())))
911 const Dimension dimRank = dimShape.size();
913 return emitError() <<
"expected non-scalar sparse tensor";
914 if (getDimRank() != dimRank)
916 <<
"dimension-rank mismatch between encoding and tensor shape: "
917 << getDimRank() <<
" != " << dimRank;
918 if (
auto expVal = getExplicitVal()) {
919 Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
920 if (attrType != elementType) {
921 return emitError() <<
"explicit value type mismatch between encoding and "
922 <<
"tensor element type: " << attrType
923 <<
" != " << elementType;
926 if (
auto impVal = getImplicitVal()) {
927 Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
928 if (attrType != elementType) {
929 return emitError() <<
"implicit value type mismatch between encoding and "
930 <<
"tensor element type: " << attrType
931 <<
" != " << elementType;
934 auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
935 auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
936 auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
937 if ((impFVal && impFVal.getValue().isNonZero()) ||
938 (impIntVal && !impIntVal.getValue().isZero()) ||
939 (impComplexVal && (impComplexVal.getImag().isNonZero() ||
940 impComplexVal.getReal().isNonZero()))) {
941 return emitError() <<
"implicit value must be zero";
947 Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart()
const {
949 assert(coo.size() == 1 || coo.empty());
950 if (!coo.empty() && coo.front().isAoS()) {
951 return coo.front().lvlRange.first;
957 mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments()
const {
959 if (getLvlRank() <= 1)
964 while (l < getLvlRank()) {
967 auto cur = lts.begin() + l;
968 auto end = std::find_if(cur + 1, lts.end(), [](
LevelType lt) {
969 return !lt.isa<LevelFormat::Singleton>();
971 unsigned cooLen = std::distance(cur, end);
977 ret.push_back(
COOSegment{std::make_pair(l, l + cooLen),
996 if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
998 for (
Level l = startLvl + 1; l < lvlRank; ++l)
999 if (!isSingletonLvl(l))
1004 return !
isUnique || isUniqueLvl(lvlRank - 1);
1010 lvlTypes.reserve(lvlRank);
1017 std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
1023 getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
1024 getCrdWidth(), getExplicitVal(), getImplicitVal());
1032 SparseTensorEncodingAttr
1034 if (
auto ttp = llvm::dyn_cast<RankedTensorType>(type))
1035 return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
1036 if (
auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
1037 return mdtp.getEncoding();
1043 auto map =
static_cast<AffineMap>(dimToLvl);
1060 lvlExprs.reserve(numLvls);
1063 std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
1064 for (
unsigned i = 0, n = numLvls; i < n; i++) {
1066 if (
auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1069 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1070 assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
1071 "expected only one floordiv for each dimension");
1076 components.push_back(binOp.getRHS());
1078 lvlExprComponents[pos] = components;
1080 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1081 assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
1082 "expected floordiv before mod");
1087 assert(
false &&
"expected floordiv or mod");
1097 for (
auto &components : lvlExprComponents) {
1098 assert(components.second.size() == 3 &&
1099 "expected 3 components to build lvlExprs");
1104 lvlExprs.push_back(addOp);
1111 "expected dimToLvl to be block sparsity for calling getBlockSize");
1114 if (
auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1116 blockSize.push_back(
1117 dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
1120 blockSize.push_back(0);
1129 std::map<unsigned, int64_t> coeffientMap;
1130 bool hasBlock =
false;
1132 if (
auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1134 auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
1135 auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
1136 if (!dimOp || !conOp || conOp.getValue() <= 0)
1139 auto pos = dimOp.getPosition();
1142 auto [it, inserted] = coeffientMap.try_emplace(pos);
1146 it->second = conOp.getValue();
1149 auto it = coeffientMap.find(pos);
1150 if (it == coeffientMap.end())
1153 if (conOp.getValue() != it->second)
1159 }
else if (
auto dimOp = dyn_cast<AffineDimExpr>(result)) {
1160 auto pos = dimOp.getPosition();
1162 if (!coeffientMap.try_emplace(pos, 0).second)
1172 auto hasNonIdentityMap = [](
Value v) {
1177 return llvm::any_of(op->
getOperands(), hasNonIdentityMap) ||
1178 llvm::any_of(op->
getResults(), hasNonIdentityMap);
1183 assert(enc.isPermutation() &&
"Non permutation map not supported");
1184 if (
const auto dimToLvl = enc.getDimToLvl())
1192 assert(enc.isPermutation() &&
"Non permutation map not supported");
1193 if (
const auto lvlToDim = enc.getLvlToDim())
1203 static SparseTensorEncodingAttr
1206 for (
auto lt : enc.getLvlTypes())
1210 enc.getContext(), lts,
1220 enc.getDimSlices());
1223 StorageSpecifierType
1228 StorageSpecifierType
1231 SparseTensorEncodingAttr encoding) {
1250 StorageSpecifierKind mdKind, std::optional<Level> lvl,
1252 if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
1254 "redundant level argument for querying value memory size");
1257 const auto enc = md.getType().getEncoding();
1258 const Level lvlRank = enc.getLvlRank();
1260 if (mdKind == StorageSpecifierKind::DimOffset ||
1261 mdKind == StorageSpecifierKind::DimStride)
1263 return op->
emitError(
"requested slice data on non-slice tensor");
1265 if (mdKind != StorageSpecifierKind::ValMemSize) {
1267 return op->
emitError(
"missing level argument");
1269 const Level l = lvl.value();
1271 return op->
emitError(
"requested level is out of bounds");
1273 if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
1275 "requested position memory size on a singleton level");
1291 llvm_unreachable(
"Unrecognizable FieldKind");
1296 RankedTensorType valTp,
1299 return op->
emitError(
"the sparse-tensor must have static shape");
1301 return op->
emitError(
"the sparse-tensor must have an encoding attribute");
1307 auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
1309 unsigned expCOORank = stt.
getLvlRank() - cooStartLvl;
1310 if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
1311 return op->
emitError(
"input/output trailing COO level-ranks don't match");
1318 return op->
emitError(
"inconsistent number of fields between input/output");
1321 bool misMatch =
false;
1328 Type inputTp =
nullptr;
1332 assert(fid == idx && stt.
getLvlType(lvl) == lt);
1333 inputTp = lvlTps[idx++];
1336 Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
1338 if (inpElemTp != expElemTp) {
1346 return op->
emitError(
"input/output element-types don't match");
1351 RankedTensorType valuesTp = getValues().getType();
1352 const auto lvlsTp = getLevels().getTypes();
1359 return emitError(
"output values and return value type mismatch");
1361 for (
auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1362 if (ot.getType() != rt.getType())
1363 return emitError(
"output levels and return levels type mismatch");
1365 RankedTensorType valuesTp = getRetValues().getType();
1366 const auto lvlsTp = getRetLevels().getTypes();
1372 RankedTensorType tp1 = getSource().getType();
1373 RankedTensorType tp2 = getDest().getType();
1374 if (tp1.getRank() != tp2.getRank())
1375 return emitError(
"unexpected conversion mismatch in rank");
1377 llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
1378 if (dstEnc && dstEnc.isSlice())
1379 return emitError(
"cannot convert to a sparse tensor slice");
1381 auto shape1 = tp1.getShape();
1382 auto shape2 = tp2.getShape();
1386 for (
Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1387 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1388 return emitError(
"unexpected conversion mismatch in dimension ") << d;
1398 bool ConvertOp::needsExtraSort() {
1417 if (
auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1418 if (isa<SparseElementsAttr>(constOp.getValue()))
1425 uint64_t inRank = getEncoder().getLvlRank();
1426 uint64_t outRank = getEncoder().getDimRank();
1428 if (getDirection() == CrdTransDirectionKind::dim2lvl)
1429 std::swap(inRank, outRank);
1431 if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1432 return emitError(
"Coordinate rank mismatch with encoding");
1437 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
1439 if (getEncoder().isIdentity()) {
1440 results.assign(getInCrds().begin(), getInCrds().end());
1444 AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
1445 ? getEncoder().getDimToLvl()
1446 : getEncoder().getLvlToDim();
1448 results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
1453 auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
1454 bool sameDef = def && llvm::all_of(getInCrds(), [def](
Value v) {
1460 bool oppositeDir = def.getDirection() != getDirection();
1462 def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
1463 bool sameCount = def.getNumResults() == getInCrds().size();
1464 if (!oppositeDir || !sameOracle || !sameCount)
1469 bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
1470 [](
auto valuePair) {
1471 auto [lhs, rhs] = valuePair;
1479 results.append(def.getInCrds().begin(), def.getInCrds().end());
1486 return build(builder, state, source, val);
1490 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1492 if (
static_cast<uint64_t
>(lvl.value()) >= stt.
getLvlRank())
1494 "Level index exceeds the rank of the input sparse tensor");
1499 std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1509 cast<RankedTensorType>(getSource().
getType()).getRank());
1514 auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1518 Level lvl = lvlIndex.getAPSInt().getZExtValue();
1528 auto getIndexAttr = [
this](int64_t lvlSz) {
1533 if (ShapedType::isStatic(lvlShape[lvl]))
1534 return getIndexAttr(lvlShape[lvl]);
1540 SparseTensorEncodingAttr dstEnc,
Value source) {
1544 dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1547 return build(odsBuilder, odsState, dstTp, source);
1556 if (srcLvlTps.size() != dstLvlTps.size())
1557 return emitError(
"Level rank mismatch between source/dest tensors");
1559 for (
auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1560 if (srcLvlTp != dstLvlTp)
1561 return emitError(
"Level type mismatch between source/dest tensors");
1565 return emitError(
"Crd/Pos width mismatch between source/dest tensors");
1569 return emitError(
"Element type mismatch between source/dest tensors");
1573 for (
auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1574 if (srcLvlSz != dstLvlSz) {
1578 return emitError(
"Level size mismatch between source/dest tensors");
1585 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1589 if (
auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1591 if (def.getSource().getType() == getDest().
getType())
1592 return def.getSource();
1597 template <
typename ToBufferOp>
1602 typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1604 Type elemTp =
nullptr;
1605 bool withStride =
false;
1606 if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1608 }
else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1609 std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1611 if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1613 }
else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1617 assert(elemTp &&
"unhandled operation.");
1619 bufShape.push_back(ShapedType::kDynamic);
1623 {ShapedType::kDynamic})
1624 : StridedLayoutAttr();
1632 return emitError(
"requested level is out of bounds");
1634 return emitError(
"unexpected type for positions");
1639 ToPositionsOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
1643 return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
1649 return emitError(
"requested level is out of bounds");
1651 return emitError(
"unexpected type for coordinates");
1656 ToCoordinatesOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
1660 return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
1666 return emitError(
"expected sparse tensor with a COO region");
1670 LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1674 return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
1682 return emitError(
"unexpected mismatch in element types");
1686 LogicalResult ToValuesOp::inferReturnTypes(
MLIRContext *ctx,
1687 std::optional<Location> loc,
1692 return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
1696 auto rank =
getSlice().getType().getRank();
1697 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1698 return emitError(
"requested dimension out of bound");
1703 auto rank =
getSlice().getType().getRank();
1704 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1705 return emitError(
"requested dimension out of bound");
1711 getSpecifier(), getOperation());
1714 template <
typename SpecifierOp>
1716 return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1719 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1720 const StorageSpecifierKind
kind = getSpecifierKind();
1721 const auto lvl = getLevel();
1723 if (
kind == op.getSpecifierKind() && lvl == op.getLevel())
1724 return op.getValue();
1730 getSpecifier(), getOperation());
1735 const char *regionName,
1738 unsigned expectedNum = inputTypes.size();
1739 if (numArgs != expectedNum)
1740 return op->emitError() << regionName <<
" region must have exactly "
1741 << expectedNum <<
" arguments";
1743 for (
unsigned i = 0; i < numArgs; i++) {
1745 if (typ != inputTypes[i])
1746 return op->emitError() << regionName <<
" region argument " << (i + 1)
1747 <<
" type mismatch";
1750 YieldOp yield = dyn_cast<YieldOp>(term);
1752 return op->emitError() << regionName
1753 <<
" region must end with sparse_tensor.yield";
1754 if (!yield.hasSingleResult() ||
1755 yield.getSingleResult().getType() != outputType)
1756 return op->emitError() << regionName <<
" region yield type mismatch";
1763 Type leftType = getX().getType();
1764 Type rightType = getY().getType();
1765 Type outputType = getOutput().getType();
1766 Region &overlap = getOverlapRegion();
1767 Region &left = getLeftRegion();
1768 Region &right = getRightRegion();
1772 if (!overlap.
empty()) {
1774 TypeRange{leftType, rightType}, outputType)))
1777 if (!left.
empty()) {
1781 }
else if (getLeftIdentity()) {
1782 if (leftType != outputType)
1783 return emitError(
"left=identity requires first argument to have the same "
1784 "type as the output");
1786 if (!right.
empty()) {
1790 }
else if (getRightIdentity()) {
1791 if (rightType != outputType)
1792 return emitError(
"right=identity requires second argument to have the "
1793 "same type as the output");
1799 Type inputType = getX().getType();
1800 Type outputType = getOutput().getType();
1804 Region &present = getPresentRegion();
1805 if (!present.
empty()) {
1810 Region &absent = getAbsentRegion();
1811 if (!absent.
empty()) {
1817 Block *parent = getOperation()->getBlock();
1819 cast<YieldOp>(absentBlock->
getTerminator()).getSingleResult();
1820 if (
auto arg = dyn_cast<BlockArgument>(absentVal)) {
1821 if (arg.getOwner() == parent)
1822 return emitError(
"absent region cannot yield linalg argument");
1824 if (!isa<arith::ConstantOp>(def) &&
1825 (def->getBlock() == absentBlock || def->getBlock() == parent))
1826 return emitError(
"absent region cannot yield locally computed value");
1832 bool ConcatenateOp::needsExtraSort() {
1837 bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](
Value op) {
1844 bool directLowerable =
1845 allSameOrdered && getDimension() == 0 && dstStt.
isIdentity();
1846 return !directLowerable;
1851 const Dimension concatDim = getDimension();
1852 const Dimension dimRank = dstTp.getDimRank();
1854 if (getInputs().size() <= 1)
1855 return emitError(
"Need at least two tensors to concatenate.");
1857 if (concatDim >= dimRank)
1859 "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1860 concatDim, dimRank));
1863 const auto i = it.index();
1865 if (srcTp.hasDynamicDimShape())
1866 return emitError(llvm::formatv(
"Input tensor ${0} has dynamic shape", i));
1867 const Dimension srcDimRank = srcTp.getDimRank();
1868 if (srcDimRank != dimRank)
1870 llvm::formatv(
"Input tensor ${0} has a different rank (rank={1}) "
1871 "from the output tensor (rank={2}).",
1872 i, srcDimRank, dimRank));
1875 for (
Dimension d = 0; d < dimRank; d++) {
1876 const Size dstSh = dstTp.getDimShape()[d];
1877 if (d == concatDim) {
1878 if (ShapedType::isStatic(dstSh)) {
1883 for (
const auto src : getInputs())
1889 "The concatenation dimension of the output tensor should be the "
1890 "sum of all the concatenation dimensions of the input tensors.");
1894 for (
const auto src : getInputs()) {
1896 if (ShapedType::isStatic(prev) && sh != prev)
1897 return emitError(
"All dimensions (expect for the concatenating one) "
1898 "should be equal.");
1909 build(builder, result, curSize, inBuffer, value,
Value());
1915 if (nValue && nValue.value() < 1)
1916 return emitOpError(
"n must be not less than 1");
1923 if (stt.
getLvlRank() != 1 +
static_cast<Level>(getLvlCoords().size()))
1924 return emitOpError(
"incorrect number of coordinates");
1928 void ForeachOp::build(
1933 build(builder, result, initArgs.
getTypes(), tensor, initArgs, order);
1945 blockArgTypes.append(initArgs.
getTypes().begin(), initArgs.
getTypes().end());
1950 auto ®ion = *result.
regions.front();
1952 builder.
createBlock(®ion, region.end(), blockArgTypes, blockArgLocs);
1953 bodyBuilder(builder, result.
location,
1961 const Dimension dimRank = t.getDimRank();
1962 const auto args = getBody()->getArguments();
1964 if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1965 return emitError(
"Level traverse order does not match tensor's level rank");
1967 if (dimRank + 1 + getInitArgs().size() != args.size())
1968 return emitError(
"Unmatched number of arguments in the block");
1970 if (getNumResults() != getInitArgs().size())
1971 return emitError(
"Mismatch in number of init arguments and results");
1973 if (getResultTypes() != getInitArgs().getTypes())
1974 return emitError(
"Mismatch in types of init arguments and results");
1977 auto yield = cast<YieldOp>(getBody()->getTerminator());
1978 if (yield.getNumOperands() != getNumResults() ||
1979 yield.getOperands().getTypes() != getResultTypes())
1980 return emitError(
"Mismatch in types of yield values and results");
1986 llvm::formatv(
"Expecting Index type for argument at index {0}", d));
1988 const auto elemTp = t.getElementType();
1989 const auto valueTp = args[dimRank].getType();
1990 if (elemTp != valueTp)
1992 llvm::formatv(
"Unmatched element type between input tensor and "
1993 "block argument, expected:{0}, got: {1}",
2001 return getInputCoo();
2011 return emitError(
"Expected COO sparse tensors only");
2014 return emitError(
"Unmatched dim2lvl map between input and result COO");
2019 return emitError(
"Unmatched storage format between input and result COO");
2025 Type inputType = getX().getType();
2026 Region &formula = getRegion();
2028 TypeRange{inputType, inputType}, inputType);
2033 Type inputType = getX().getType();
2034 Type boolType = b.getI1Type();
2035 Region &formula = getRegion();
2044 return emitError(llvm::formatv(
"Expected rank(perm_map) > 1, got {0}", nx));
2048 llvm::formatv(
"Expected a permutation map, got {0}", xPerm));
2057 const auto checkDim = [&](
Value v,
Size minSize,
2058 const char *message) -> LogicalResult {
2060 if (ShapedType::isStatic(sh) && sh < minSize)
2062 llvm::formatv(
"{0} got {1} < {2}", message, sh, minSize));
2065 uint64_t n = cn.value();
2067 if (
auto nyAttr = getNyAttr())
2068 ny = nyAttr.getInt();
2069 if (
failed(checkDim(getXy(), n * (nx + ny),
2070 "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
2072 for (
Value opnd : getYs())
2073 if (
failed(checkDim(opnd, n,
"Expected dimension(y) >= n")))
2083 IterSpaceType IteratorType::getIterSpaceType()
const {
2088 IteratorType IterSpaceType::getIteratorType()
const {
2108 "expect larger level upper bound than lower bound");
2116 IntegerAttr &lvlHiAttr) {
2133 p << lo <<
" to " << hi;
2139 IntegerAttr lvlHi) {
2140 unsigned lo = lvlLo.getValue().getZExtValue();
2141 unsigned hi = lvlHi.getValue().getZExtValue();
2155 ParseResult crdList =
2158 if (parser.parseArgument(definedArgs.emplace_back()))
2160 definedSet.set(cnt);
2168 "parsed more value than expected.");
2173 "expecting SSA value or \"_\" for level coordinates");
2175 assert(definedArgs.size() == definedSet.
count());
2182 if (definedSet.
empty())
2185 for (
unsigned i = 0; i < size; i++) {
2186 if (definedSet[i]) {
2187 p << blocksArgs.front();
2188 blocksArgs = blocksArgs.drop_front();
2195 assert(blocksArgs.empty());
2208 for (
auto &coord : coords)
2212 state.addAttribute(
"crdUsedLvls",
2229 if (iterators.size() != spaces.size())
2232 "mismatch in number of sparse iterators and sparse spaces");
2237 size_t numCrds = coords.size();
2245 blockArgs.append(coords);
2251 if (iterSpaceTps.size() != spaces.size())
2253 "mismatch in number of iteration space operands "
2254 "and iteration space types");
2256 for (
auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
2257 IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
2260 "expected sparse_tensor.iter_space type for "
2261 "iteration space operands");
2262 it.type = spaceTp.getIteratorType();
2277 if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2280 "mismatch in number of iteration arguments and return values");
2283 for (
auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2305 size_t numCrds = coords.size();
2313 blockArgs.append(coords);
2321 if (iterSpaceTps.size() != spaces.size())
2323 "mismatch in number of iteration space operands "
2324 "and iteration space types");
2334 state.operands.append(spacesVals);
2339 if (args.size() != initArgs.size() || args.size() != state.types.size()) {
2342 "mismatch in number of iteration arguments and return values");
2345 for (
auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
2354 LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2359 ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2362 adaptor.getHiLvl()));
2367 if (getLoLvl() >= getHiLvl())
2368 return emitOpError(
"expected smaller level low than level high");
2371 if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2373 "parent iterator should be specified iff level lower bound equals 0");
2377 IterSpaceType spaceTp = getExtractedSpace().getType();
2378 if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2380 "mismatch in parent iterator encoding and iteration space encoding.");
2382 if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2383 return emitOpError(
"parent iterator should be used to extract an "
2384 "iteration space from a consecutive level.");
2392 auto itTp = getIterator().getType();
2395 return emitOpError(
"mismatch in tensor encoding and iterator encoding.");
2398 return emitOpError(
"must use last-level iterator to extract values. ");
2409 llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
2410 for (
unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
2411 if (
auto crd = iterateOp.getLvlCrd(i)) {
2412 if (crd->getUsers().empty())
2413 toRemove.set(crd->getArgNumber());
2420 if (toRemove.none())
2424 iterateOp.setCrdUsedLvls(newUsedLvls);
2425 iterateOp.getBody()->eraseArguments(toRemove);
2438 unsigned rank = llvm::cast<IterSpaceType>(iterSpace.
getType()).getSpaceDim();
2441 return build(builder, odsState, iterSpace, initArgs, set);
2458 for (
Value v : initArgs)
2462 for (
unsigned i = 0, e = crdUsedLvls.
count(); i < e; i++)
2467 llvm::cast<IterSpaceType>(iterSpace.
getType()).getIteratorType(),
2478 if (iters.size() != 1)
2480 "expected only one iterator/iteration space");
2482 iterArgs.append(iters);
2503 StringRef prefix =
"") {
2504 assert(blocksArgs.size() == initializers.size() &&
2505 "expected same length of arguments and initializers");
2506 if (initializers.empty())
2510 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
2511 p << std::get<0>(it) <<
" = " << std::get<1>(it);
2516 template <
typename SparseLoopOp>
2518 if (op.getInitArgs().size() != op.getNumResults()) {
2519 return op.emitOpError(
2520 "mismatch in number of loop-carried values and defined values");
2522 if (op.getCrdUsedLvls().max() > op.getSpaceDim())
2523 return op.emitOpError(
"required out-of-bound coordinates");
2532 p <<
" " << getIterator() <<
" in " << getIterSpace();
2533 if (!getCrdUsedLvls().empty()) {
2540 p <<
" : " << getIterSpace().getType() <<
" ";
2541 if (!getInitArgs().empty())
2546 !getInitArgs().empty());
2549 LogicalResult IterateOp::verifyRegions() {
2550 if (getIterator().
getType() != getIterSpace().
getType().getIteratorType())
2551 return emitOpError(
"mismatch in iterator and iteration space type");
2552 if (getNumRegionIterArgs() != getNumResults())
2554 "mismatch in number of basic block args and defined values");
2556 auto initArgs = getInitArgs();
2557 auto iterArgs = getRegionIterArgs();
2558 auto yieldVals = getYieldedValues();
2559 auto opResults = getResults();
2560 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2561 opResults.size()})) {
2562 return emitOpError() <<
"number mismatch between iter args and results.";
2565 for (
auto [i, init, iter, yield, ret] :
2567 if (init.getType() != ret.getType())
2568 return emitOpError() <<
"types mismatch between " << i
2569 <<
"th iter operand and defined value";
2570 if (iter.getType() != ret.getType())
2571 return emitOpError() <<
"types mismatch between " << i
2572 <<
"th iter region arg and defined value";
2573 if (yield.getType() != ret.getType())
2574 return emitOpError() <<
"types mismatch between " << i
2575 <<
"th yield value and defined value";
2585 return getInitArgsMutable();
2589 return getRegion().getArguments().take_front(getNumRegionIterArgs());
2592 std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
2593 return cast<sparse_tensor::YieldOp>(
2594 getRegion().getBlocks().front().getTerminator())
2595 .getResultsMutable();
2598 std::optional<ResultRange> IterateOp::getLoopResults() {
return getResults(); }
2601 return getInitArgs();
2608 regions.push_back(
RegionSuccessor(&getRegion(), getRegionIterArgs()));
2615 unsigned numCases) {
2617 cast<IterSpaceType>(iterSpaces.front().
getType()).getSpaceDim();
2626 return CoIterateOp::build(builder, odsState, initArgs.
getTypes(), iterSpaces,
2627 initArgs, set, cases,
2642 {static_cast<int32_t>(spaces.size()),
2643 static_cast<int32_t>(result.types.size())}));
2658 auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].
getType());
2659 definedIts[i].type = spaceTp.getIteratorType();
2661 definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
2680 llvm::interleaveComma(getIterSpaces(), p, [&](
auto s) { p << s; });
2683 if (!getCrdUsedLvls().empty()) {
2691 p <<
" : (" << getIterSpaces().getTypes() <<
")";
2692 if (!getInitArgs().empty())
2693 p.printArrowTypeList(getInitArgs().getTypes());
2695 for (
unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
2699 getRegionDefinedSpace(idx));
2701 p.printRegion(getRegion(idx),
false,
2702 !getInitArgs().empty());
2706 ValueRange CoIterateOp::getYieldedValues(
unsigned regionIdx) {
2707 return cast<sparse_tensor::YieldOp>(
2708 getRegion(regionIdx).getBlocks().front().getTerminator())
2712 LogicalResult CoIterateOp::verifyRegions() {
2713 for (
unsigned r = 0, e = getNumRegions(); r < e; r++) {
2714 if (getNumRegionIterArgs() != getNumResults())
2716 "mismatch in number of basic block args and defined values");
2718 auto initArgs = getInitArgs();
2719 auto iterArgs = getRegionIterArgs(r);
2720 auto yieldVals = getYieldedValues(r);
2721 auto opResults = getResults();
2722 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2723 opResults.size()})) {
2724 return emitOpError()
2725 <<
"number mismatch between iter args and results on " << r
2729 for (
auto [i, init, iter, yield, ret] :
2731 if (init.getType() != ret.getType())
2732 return emitOpError()
2733 <<
"types mismatch between " << i
2734 <<
"th iter operand and defined value on " << r <<
"th region";
2735 if (iter.getType() != ret.getType())
2736 return emitOpError() <<
"types mismatch between " << i
2737 <<
"th iter region arg and defined value on " << r
2739 if (yield.getType() != ret.getType())
2740 return emitOpError()
2741 <<
"types mismatch between " << i
2742 <<
"th yield value and defined value on " << r <<
"th region";
2746 auto cases = getRegionDefinedSpaces();
2747 llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
2748 if (set.size() != getNumRegions())
2749 return emitOpError(
"contains duplicated cases.");
2756 I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
2757 for (
Region &r : getCaseRegions())
2758 if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
2773 if (
auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2778 void SparseTensorDialect::initialize() {
2780 #define GET_ATTRDEF_LIST
2781 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
2784 #define GET_TYPEDEF_LIST
2785 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
2789 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2791 declarePromisedInterfaces<
2792 bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
2793 NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
2794 ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
2797 #define GET_OP_CLASSES
2798 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2800 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static bool isPermutation(const std::vector< PermutationTy > &permutation)
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
union mlir::linalg::@1242::ArityGroupAndKind::Kind kind
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
bool isUnique(It begin, It end)
static LogicalResult verifyNumBlockArgs(T *op, Region ®ion, const char *regionName, TypeRange inputTypes, Type outputType)
static ParseResult parseOptionalStaticSlice(int64_t &result, AsmParser &parser)
static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc)
We normalized sparse tensor encoding attribute by always using ordered/unique LT such that "compresse...
static ParseResult parseUsedCoordList(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &coords)
static LogicalResult isMatchingWidth(Value mem, unsigned width)
static constexpr bool acceptBitWidth(unsigned bitWidth)
static mlir::ParseResult parseLevelRange(mlir::AsmParser &, mlir::sparse_tensor::Level &, mlir::sparse_tensor::Level &)
Parses a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static LogicalResult lvlIsInBounds(Level lvl, Value tensor)
static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size, Block::BlockArgListType blocksArgs, I64BitSet definedSet)
static constexpr FieldIndex kDataFieldStartingIdx
static constexpr Level kInvalidLevel
static LogicalResult verifySparseLoopOp(SparseLoopOp op)
static constexpr Level kInvalidFieldIndex
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, mlir::sparse_tensor::Level)
Prints a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind)
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op)
static SmallVector< Size > getSparseFieldShape(const SparseTensorEncodingAttr enc, std::optional< ArrayRef< int64_t >> dimShape)
static ParseResult parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &iterators, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static ParseResult parseOptionalDefinedList(OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, SmallVectorImpl< OpAsmParser::Argument > &definedArgs, unsigned maxCnt=std::numeric_limits< unsigned >::max(), OpAsmParser::Delimiter delimiter=OpAsmParser::Delimiter::Paren)
Parses a list of optional defined list in the form of "(%val0, _, %val1, ...)", where _ is used to an...
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, SparseTensorType stt, RankedTensorType valTp, TypeRange lvlTps)
static ParseResult parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< Value > &spacesVals, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static LogicalResult verifySparsifierGetterSetter(StorageSpecifierKind mdKind, std::optional< Level > lvl, TypedValue< StorageSpecifierType > md, Operation *op)
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, OpaqueProperties prop, RegionRange region, SmallVectorImpl< mlir::Type > &ret)
static bool isAllDense(uint64_t lvlRank, const LevelType *lvlTypes)
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
Base type for affine expression.
void print(raw_ostream &os) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseLBrace()=0
Parse a { token.
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseQuestion()=0
Parse a '?' token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printArrowTypeList(TypeRange &&types)
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
This class represents a single result from folding an operation.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
A simple wrapper to encode a bitset of (at most 64) levels, currently used by sparse_tensor....
iterator_range< const_set_bits_iterator > bits() const
I64BitSet & set(unsigned i)
A wrapper around RankedTensorType, which has three goals:
MLIRContext * getContext() const
Type getElementType() const
unsigned getCrdWidth() const
Returns the coordinate-overhead bitwidth, defaulting to zero.
SmallVector< Size > getBatchLvlShape() const
Returns the batched level-shape.
ArrayRef< LevelType > getLvlTypes() const
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
ArrayRef< Size > getDimShape() const
Returns the dimension-shape.
bool isAllOrdered() const
Returns true for tensors where every level is ordered.
SmallVector< Size > getLvlShape() const
Returns the level-shape.
bool isCOOType(Level startLvl=0, bool isUnique=true) const
Returns true iff this sparse tensor type has a trailing COO region starting at the given level.
Dimension getDimRank() const
Returns the dimension-rank.
bool isAllDense() const
Returns true for tensors where every level is dense.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
bool hasSameDimToLvl(const SparseTensorType &other) const
Returns true iff the two types have the same mapping.
bool hasStaticDimShape() const
Returns true if no dimension has dynamic size.
Level getLvlRank() const
Returns the level-rank.
unsigned getPosWidth() const
Returns the position-overhead bitwidth, defaulting to zero.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
LevelType getLvlType(Level l) const
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
Provides methods to access fields of a sparse tensor with the given encoding.
unsigned getNumDataFields() const
Gets the total number of data fields (coordinate arrays, position arrays, and a value array) for the ...
unsigned getNumFields() const
Gets the total number of fields for the given sparse tensor encoding.
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with t...
std::pair< FieldIndex, unsigned > getFieldIndexAndStride(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Parses the Sparse Tensor Encoding Attribute (STEA).
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool isWithCrdLT(LevelType lt)
bool isWithPosLT(LevelType lt)
bool isOrderedLT(LevelType lt)
std::string toMLIRString(LevelType lt)
Dimension toDim(SparseTensorEncodingAttr enc, Level l)
Convenience method to translate the given level to the corresponding dimension.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
bool isSingletonLT(LevelType lt)
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
uint64_t Level
The type of level identifiers and level-ranks.
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
uint64_t getN(LevelType lt)
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
llvm::hash_code hash_value(LevelType lt)
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context)
Given the dimToLvl map, infers the lvlToDim map, or returns empty Affine map when inference fails.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
bool isBlockSparsity(AffineMap dimToLvl)
Given the dimToLvl map, returns if it's block sparsity.
bool isDenseLT(LevelType lt)
uint64_t getM(LevelType lt)
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SparseTensorFieldKind
===-------------------------------------------------------------------—===// The sparse tensor storag...
bool isBatchLT(LevelType lt)
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context)
Returns the lvlToDim map for the given dimToLvl map specific to the block sparse cases.
std::optional< LevelType > buildLevelType(LevelFormat lf, const std::vector< LevelPropNonDefault > &properties, uint64_t n=0, uint64_t m=0)
bool isNOutOfMLT(LevelType lt)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
LogicalResult matchAndRewrite(IterateOp iterateOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Region * addRegion()
Create a region that should be attached to the operation.
A simple structure that encodes a range of levels in the sparse tensors that forms a COO segment.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.
LevelType stripStorageIrrelevantProperties() const