26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/FormatVariadic.h"
29 #define GET_ATTRDEF_CLASSES
30 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
31 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
41 #define GET_TYPEDEF_CLASSES
42 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
79 if (dimShape.has_value()) {
83 enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
84 memrefShape.assign(lvlShape.begin(),
85 lvlShape.begin() + enc.getBatchLvlRank());
88 memrefShape.push_back(ShapedType::kDynamic);
104 const auto lvlTypes = enc.getLvlTypes();
105 const Level lvlRank = enc.getLvlRank();
111 for (
Level l = 0; l < lvlRank; ) {
112 const auto lt = lvlTypes[l];
121 if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
122 if (!cooSegsRef.front().isSoA) {
125 l = cooSegsRef.front().lvlRange.second;
131 cooSegsRef = cooSegsRef.drop_front();
171 return callback(specType, fieldIdx, fieldKind, lvl, lt);
173 return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
175 return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
177 return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
179 llvm_unreachable(
"unrecognized field kind");
184 unsigned numFields = 0;
194 unsigned numFields = 0;
206 std::pair<FieldIndex, unsigned>
208 std::optional<Level> lvl)
const {
212 assert(lvl.has_value());
214 const Level lvlRank = enc.getLvlRank();
215 if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
217 stride = lvlRank - cooStart;
223 if ((lvl && fLvl == lvl.value() && kind == fKind) ||
232 return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
239 std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
240 return isDynamic(v) ? std::nullopt
241 : std::make_optional(
static_cast<uint64_t
>(v));
244 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset()
const {
245 return getStatic(getOffset());
248 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride()
const {
249 return getStatic(getStride());
252 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize()
const {
253 return getStatic(getSize());
256 bool SparseTensorDimSliceAttr::isCompletelyDynamic()
const {
257 return isDynamic(getOffset()) && isDynamic(getStride()) &&
258 isDynamic(getSize());
261 std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
262 return isDynamic(v) ?
"?" : std::to_string(v);
266 assert(getImpl() &&
"Uninitialized SparseTensorDimSliceAttr");
268 os << getStaticString(getOffset());
270 os << getStaticString(getSize());
272 os << getStaticString(getStride());
283 if (parseResult.has_value()) {
284 if (parseResult.value().succeeded() && result < 0) {
287 "expect positive value or ? for slice offset/size/stride");
290 return parseResult.value();
294 result = SparseTensorDimSliceAttr::kDynamic;
299 int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
311 offset, size, stride);
316 int64_t offset, int64_t size, int64_t stride) {
317 if (!isDynamic(offset) && offset < 0)
318 return emitError() <<
"expect non-negative value or ? for slice offset";
319 if (!isDynamic(size) && size <= 0)
320 return emitError() <<
"expect positive value or ? for slice size";
321 if (!isDynamic(stride) && stride <= 0)
322 return emitError() <<
"expect positive value or ? for slice stride";
326 SparseTensorEncodingAttr
327 SparseTensorEncodingAttr::withDimToLvl(
AffineMap dimToLvl)
const {
328 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
331 getCrdWidth(), getExplicitVal(), getImplicitVal());
334 SparseTensorEncodingAttr
335 SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc)
const {
336 return withDimToLvl(enc ? enc.getDimToLvl() :
AffineMap());
339 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl()
const {
343 SparseTensorEncodingAttr
344 SparseTensorEncodingAttr::withBitWidths(
unsigned posWidth,
345 unsigned crdWidth)
const {
346 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
348 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
349 crdWidth, getExplicitVal(), getImplicitVal());
352 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths()
const {
353 return withBitWidths(0, 0);
356 SparseTensorEncodingAttr
357 SparseTensorEncodingAttr::withExplicitVal(
Attribute explicitVal)
const {
358 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
360 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
361 getCrdWidth(), explicitVal, getImplicitVal());
364 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal()
const {
368 SparseTensorEncodingAttr
369 SparseTensorEncodingAttr::withImplicitVal(
Attribute implicitVal)
const {
370 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
372 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
373 getCrdWidth(), getExplicitVal(), implicitVal);
376 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal()
const {
380 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
383 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
384 getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
387 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices()
const {
391 uint64_t SparseTensorEncodingAttr::getBatchLvlRank()
const {
393 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(),
isBatchLT);
394 return std::distance(lastBatch, lvlTypes.rend());
398 return !getImpl() || llvm::all_of(getLvlTypes(),
isDenseLT);
401 bool SparseTensorEncodingAttr::isAllOrdered()
const {
402 return !getImpl() || llvm::all_of(getLvlTypes(),
isOrderedLT);
405 Type SparseTensorEncodingAttr::getCrdElemType()
const {
413 Type SparseTensorEncodingAttr::getPosElemType()
const {
421 MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
427 MemRefType SparseTensorEncodingAttr::getPosMemRefType(
433 bool SparseTensorEncodingAttr::isIdentity()
const {
434 return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
438 return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
441 Dimension SparseTensorEncodingAttr::getDimRank()
const {
442 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
443 const auto dimToLvl = getDimToLvl();
444 return dimToLvl ? dimToLvl.
getNumDims() : getLvlRank();
447 Level SparseTensorEncodingAttr::getLvlRank()
const {
448 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
449 return getLvlTypes().size();
455 assert(l < getLvlRank() &&
"Level is out of bounds");
456 return getLvlTypes()[l];
459 bool SparseTensorEncodingAttr::isSlice()
const {
460 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
461 return !getDimSlices().empty();
464 SparseTensorDimSliceAttr
465 SparseTensorEncodingAttr::getDimSlice(
Dimension dim)
const {
466 assert(isSlice() &&
"Is not a slice");
467 const auto dimSlices = getDimSlices();
468 assert(dim < dimSlices.size() &&
"Dimension is out of bounds");
469 return dimSlices[dim];
472 std::optional<uint64_t>
473 SparseTensorEncodingAttr::getStaticDimSliceOffset(
Dimension dim)
const {
474 return getDimSlice(dim).getStaticOffset();
477 std::optional<uint64_t>
478 SparseTensorEncodingAttr::getStaticDimSliceStride(
Dimension dim)
const {
479 return getDimSlice(dim).getStaticStride();
482 std::optional<uint64_t>
483 SparseTensorEncodingAttr::getStaticLvlSliceOffset(
Level lvl)
const {
484 return getStaticDimSliceOffset(
toDim(*
this, lvl));
487 std::optional<uint64_t>
488 SparseTensorEncodingAttr::getStaticLvlSliceStride(
Level lvl)
const {
489 return getStaticDimSliceStride(
toDim(*
this, lvl));
494 CrdTransDirectionKind dir)
const {
500 dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
504 for (
unsigned r = 0; r < rank; r++) {
505 unsigned trans = dir == CrdTransDirectionKind::dim2lvl ?
toDim(*
this, r)
507 ret.push_back(srcShape[trans]);
514 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
517 dimRep.reserve(srcShape.size());
518 for (int64_t sz : srcShape) {
519 if (!ShapedType::isDynamic(sz)) {
533 if (
auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
534 ret.push_back(c.getValue() + 1);
536 if (
auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
540 if (
auto bound = llvm::dyn_cast<AffineConstantExpr>(
mod.getRHS())) {
541 ret.push_back(bound.getValue());
545 ret.push_back(ShapedType::kDynamic);
548 assert(ret.size() == rank);
555 CrdTransDirectionKind dir)
const {
560 dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
562 auto transOp = builder.
create<CrdTranslateOp>(loc, retType, crds, dir, *
this);
563 return transOp.getOutCrds();
578 unsigned posWidth = 0;
579 unsigned crdWidth = 0;
584 "explicitVal",
"implicitVal"};
587 auto *it = find(keys, attrName);
588 if (it == keys.end()) {
592 unsigned keyWordIndex = it - keys.begin();
597 switch (keyWordIndex) {
600 auto res = cParser.parseDimLvlMap();
603 const auto &dlm = *res;
605 const Level lvlRank = dlm.getLvlRank();
606 for (
Level lvl = 0; lvl < lvlRank; lvl++)
607 lvlTypes.push_back(dlm.getLvlType(lvl));
609 const Dimension dimRank = dlm.getDimRank();
610 for (
Dimension dim = 0; dim < dimRank; dim++)
611 dimSlices.push_back(dlm.getDimSlice(dim));
615 const auto isDefined = [](SparseTensorDimSliceAttr slice) {
616 return static_cast<bool>(slice.getImpl());
618 if (llvm::any_of(dimSlices, isDefined)) {
619 const auto defaultSlice =
621 for (
Dimension dim = 0; dim < dimRank; dim++)
622 if (!isDefined(dimSlices[dim]))
623 dimSlices[dim] = defaultSlice;
628 dimToLvl = dlm.getDimToLvlMap(parser.
getContext());
629 lvlToDim = dlm.getLvlToDimMap(parser.
getContext());
636 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
639 "expected an integral position bitwidth");
642 posWidth = intAttr.getInt();
649 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
652 "expected an integral index bitwidth");
655 crdWidth = intAttr.getInt();
662 if (
auto result = llvm::dyn_cast<FloatAttr>(attr)) {
663 explicitVal = result;
664 }
else if (
auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
665 explicitVal = result;
668 "expected a numeric value for explicitVal");
677 if (
auto result = llvm::dyn_cast<FloatAttr>(attr)) {
678 implicitVal = result;
679 }
else if (
auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
680 implicitVal = result;
683 "expected a numeric value for implicitVal");
701 if (!lvlToDim || lvlToDim.
isEmpty()) {
704 return parser.
getChecked<SparseTensorEncodingAttr>(
705 parser.
getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
706 explicitVal, implicitVal, dimSlices);
710 auto map =
static_cast<AffineMap>(getDimToLvl());
714 printer <<
"<{ map = ";
715 printSymbols(map, printer);
717 printDimensions(map, printer, getDimSlices());
719 printLevels(map, printer, getLvlTypes());
723 printer <<
", posWidth = " << getPosWidth();
725 printer <<
", crdWidth = " << getCrdWidth();
726 if (getExplicitVal()) {
727 printer <<
", explicitVal = " << getExplicitVal();
729 if (getImplicitVal())
730 printer <<
", implicitVal = " << getImplicitVal();
734 void SparseTensorEncodingAttr::printSymbols(
AffineMap &map,
739 for (
unsigned i = 0, n = map.
getNumSymbols() - 1; i < n; i++)
740 printer <<
's' << i <<
", ";
746 void SparseTensorEncodingAttr::printDimensions(
749 if (!dimSlices.empty()) {
750 for (
unsigned i = 0, n = map.
getNumDims() - 1; i < n; i++)
751 printer <<
'd' << i <<
" : " << dimSlices[i] <<
", ";
753 printer <<
'd' << map.
getNumDims() - 1 <<
" : "
757 for (
unsigned i = 0, n = map.
getNumDims() - 1; i < n; i++)
758 printer <<
'd' << i <<
", ";
766 for (
unsigned i = 0, n = map.
getNumResults() - 1; i < n; i++) {
783 return emitError() <<
"unexpected position bitwidth: " << posWidth;
785 return emitError() <<
"unexpected coordinate bitwidth: " << crdWidth;
786 if (
auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(),
isSingletonLT);
787 it != std::end(lvlTypes)) {
788 if (it == lvlTypes.begin() ||
790 return emitError() <<
"expected compressed or loose_compressed level "
791 "before singleton level";
792 if (!std::all_of(it, lvlTypes.end(),
793 [](
LevelType i) { return isSingletonLT(i); }))
794 return emitError() <<
"expected all singleton lvlTypes "
795 "following a singleton level";
797 if (!std::all_of(it, lvlTypes.end(), [it](
LevelType i) {
798 return it->isa<LevelPropNonDefault::SoA>() ==
799 i.isa<LevelPropNonDefault::SoA>();
801 return emitError() <<
"expected all singleton lvlTypes stored in the "
802 "same memory layout (SoA vs AoS).";
806 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(),
isBatchLT);
807 if (!std::all_of(lastBatch, lvlTypes.rend(),
isBatchLT))
808 return emitError() <<
"Batch lvlType can only be leading levels.";
811 auto soaLvls = llvm::make_filter_range(lvlTypes, [](
LevelType lt) {
814 if (llvm::any_of(soaLvls, [](
LevelType lt) {
817 return emitError() <<
"SoA is only applicable to singleton lvlTypes.";
821 if (
auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(),
isNOutOfMLT);
822 it != std::end(lvlTypes)) {
823 if (it != lvlTypes.end() - 1)
824 return emitError() <<
"expected n_out_of_m to be the last level type";
825 if (!std::all_of(lvlTypes.begin(), it,
826 [](
LevelType i) { return isDenseLT(i); }))
827 return emitError() <<
"expected all dense lvlTypes "
828 "before a n_out_of_m level";
832 <<
"expected 1xm block structure for n_out_of_m level";
835 unsigned coefficient = 0;
836 for (
const auto &elem : sizes) {
838 if (elem != coefficient && coefficient != 0) {
839 return emitError() <<
"expected only one blocked level "
840 "with the same coefficients";
845 if (coefficient !=
getM(*it)) {
846 return emitError() <<
"expected coeffiencts of Affine expressions "
847 "to be equal to m of n_out_of_m level";
856 const Level lvlRank = lvlTypes.size();
858 return emitError() <<
"expected a non-empty array for lvlTypes";
864 <<
"level-rank mismatch between dimToLvl and lvlTypes: "
869 return emitError() <<
"failed to infer lvlToDim from dimToLvl";
870 if (lvlToDim && (inferRes != lvlToDim))
871 return emitError() <<
"expected lvlToDim to be an inverse of dimToLvl";
872 if (dimRank > lvlRank)
873 return emitError() <<
"unexpected dimToLvl mapping from " << dimRank
874 <<
" to " << lvlRank;
876 if (!dimSlices.empty()) {
877 if (dimSlices.size() != dimRank)
879 <<
"dimension-rank mismatch between dimSlices and dimToLvl: "
880 << dimSlices.size() <<
" != " << dimRank;
883 if (dimRank != lvlRank)
885 <<
"dimSlices expected dimension-rank to match level-rank: "
886 << dimRank <<
" != " << lvlRank;
897 getPosWidth(), getCrdWidth(), getExplicitVal(),
898 getImplicitVal(), getDimSlices())))
903 const Dimension dimRank = dimShape.size();
905 return emitError() <<
"expected non-scalar sparse tensor";
906 if (getDimRank() != dimRank)
908 <<
"dimension-rank mismatch between encoding and tensor shape: "
909 << getDimRank() <<
" != " << dimRank;
921 if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
923 for (
Level l = startLvl + 1; l < lvlRank; ++l)
924 if (!isSingletonLvl(l))
929 return !
isUnique || isUniqueLvl(lvlRank - 1);
934 assert(coo.size() == 1 || coo.empty());
935 if (!coo.empty() && coo.front().isAoS()) {
936 return coo.front().lvlRange.first;
944 if (!hasEncoding() || lvlRank <= 1)
949 while (l < lvlRank) {
952 auto cur = lts.begin() + l;
953 auto end = std::find_if(cur + 1, lts.end(), [](
LevelType lt) {
954 return !lt.isa<LevelFormat::Singleton>();
956 unsigned cooLen = std::distance(cur, end);
962 ret.push_back(
COOSegment{std::make_pair(l, l + cooLen),
976 lvlTypes.reserve(lvlRank);
983 std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
989 getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
990 getCrdWidth(), getExplicitVal(), getImplicitVal());
998 SparseTensorEncodingAttr
1000 if (
auto ttp = llvm::dyn_cast<RankedTensorType>(type))
1001 return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
1002 if (
auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
1003 return mdtp.getEncoding();
1009 auto map =
static_cast<AffineMap>(dimToLvl);
1026 lvlExprs.reserve(numLvls);
1029 std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
1030 for (
unsigned i = 0, n = numLvls; i < n; i++) {
1032 if (
auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1035 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1036 assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
1037 "expected only one floordiv for each dimension");
1042 components.push_back(binOp.getRHS());
1044 lvlExprComponents[pos] = components;
1046 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1047 assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
1048 "expected floordiv before mod");
1053 assert(
false &&
"expected floordiv or mod");
1063 for (
auto &components : lvlExprComponents) {
1064 assert(components.second.size() == 3 &&
1065 "expected 3 components to build lvlExprs");
1070 lvlExprs.push_back(addOp);
1077 "expected dimToLvl to be block sparsity for calling getBlockSize");
1080 if (
auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1082 blockSize.push_back(
1083 dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
1086 blockSize.push_back(0);
1095 std::map<unsigned, int64_t> coeffientMap;
1096 bool hasBlock =
false;
1098 if (
auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1100 auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
1101 auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
1102 if (!dimOp || !conOp || conOp.getValue() <= 0)
1105 auto pos = dimOp.getPosition();
1108 if (coeffientMap.find(pos) != coeffientMap.end())
1111 coeffientMap[pos] = conOp.getValue();
1114 if (coeffientMap.find(pos) == coeffientMap.end())
1117 if (conOp.getValue() != coeffientMap[pos])
1123 }
else if (
auto dimOp = dyn_cast<AffineDimExpr>(result)) {
1124 auto pos = dimOp.getPosition();
1126 if (coeffientMap.find(pos) != coeffientMap.end())
1128 coeffientMap[pos] = 0;
1137 auto hasNonIdentityMap = [](
Value v) {
1142 return llvm::any_of(op->
getOperands(), hasNonIdentityMap) ||
1143 llvm::any_of(op->
getResults(), hasNonIdentityMap);
1148 assert(enc.isPermutation() &&
"Non permutation map not supported");
1149 if (
const auto dimToLvl = enc.getDimToLvl())
1157 assert(enc.isPermutation() &&
"Non permutation map not supported");
1158 if (
const auto lvlToDim = enc.getLvlToDim())
1168 static SparseTensorEncodingAttr
1171 for (
auto lt : enc.getLvlTypes())
1175 enc.getContext(), lts,
1185 enc.getDimSlices());
1188 StorageSpecifierType
1207 StorageSpecifierKind mdKind, std::optional<Level> lvl,
1209 if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
1211 "redundant level argument for querying value memory size");
1214 const auto enc = md.getType().getEncoding();
1215 const Level lvlRank = enc.getLvlRank();
1217 if (mdKind == StorageSpecifierKind::DimOffset ||
1218 mdKind == StorageSpecifierKind::DimStride)
1220 return op->
emitError(
"requested slice data on non-slice tensor");
1222 if (mdKind != StorageSpecifierKind::ValMemSize) {
1224 return op->
emitError(
"missing level argument");
1226 const Level l = lvl.value();
1228 return op->
emitError(
"requested level is out of bounds");
1230 if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
1232 "requested position memory size on a singleton level");
1248 llvm_unreachable(
"Unrecognizable FieldKind");
1253 RankedTensorType valTp,
1256 return op->
emitError(
"the sparse-tensor must have static shape");
1258 return op->
emitError(
"the sparse-tensor must have an encoding attribute");
1264 auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
1266 unsigned expCOORank = stt.
getLvlRank() - cooStartLvl;
1267 if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
1268 op->
emitError(
"input/output trailing COO level-ranks don't match");
1275 return op->
emitError(
"inconsistent number of fields between input/output");
1278 bool misMatch =
false;
1285 Type inputTp =
nullptr;
1289 assert(fid == idx && stt.
getLvlType(lvl) == lt);
1290 inputTp = lvlTps[idx++];
1293 Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
1295 if (inpElemTp != expElemTp) {
1303 return op->
emitError(
"input/output element-types don't match");
1309 const auto lvlsTp = getLevels().getTypes();
1315 if (getOutValues().getType() != getRetValues().getType())
1316 return emitError(
"output values and return value type mismatch");
1318 for (
auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1319 if (ot.getType() != rt.getType())
1320 return emitError(
"output levels and return levels type mismatch");
1323 const auto lvlsTp = getRetLevels().getTypes();
1329 if (
auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
1330 if (
auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
1331 if (tp1.getRank() != tp2.getRank())
1332 return emitError(
"unexpected conversion mismatch in rank");
1334 llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
1335 if (dstEnc && dstEnc.isSlice())
1336 return emitError(
"cannot convert to a sparse tensor slice");
1338 auto shape1 = tp1.getShape();
1339 auto shape2 = tp2.getShape();
1343 for (
Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1344 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1345 return emitError(
"unexpected conversion mismatch in dimension ") << d;
1349 return emitError(
"unexpected type in convert");
1353 if (getType() == getSource().getType())
1358 bool ConvertOp::needsExtraSort() {
1377 if (
auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1378 if (isa<SparseElementsAttr>(constOp.getValue()))
1385 uint64_t inRank = getEncoder().getLvlRank();
1386 uint64_t outRank = getEncoder().getDimRank();
1388 if (getDirection() == CrdTransDirectionKind::dim2lvl)
1389 std::swap(inRank, outRank);
1391 if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1392 return emitError(
"Coordinate rank mismatch with encoding");
1399 if (getEncoder().isIdentity()) {
1400 results.assign(getInCrds().begin(), getInCrds().end());
1404 AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
1405 ? getEncoder().getDimToLvl()
1406 : getEncoder().getLvlToDim();
1408 results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
1413 auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
1414 bool sameDef = def && llvm::all_of(getInCrds(), [def](
Value v) {
1420 bool oppositeDir = def.getDirection() != getDirection();
1422 def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
1423 bool sameCount = def.getNumResults() == getInCrds().size();
1424 if (!oppositeDir || !sameOracle || !sameCount)
1429 bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
1430 [](
auto valuePair) {
1431 auto [lhs, rhs] = valuePair;
1439 results.append(def.getInCrds().begin(), def.getInCrds().end());
1445 Value val = builder.
create<arith::ConstantIndexOp>(state.location, index);
1446 return build(builder, state, source, val);
1450 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1452 if (
static_cast<uint64_t
>(lvl.value()) >= stt.
getLvlRank())
1453 emitError(
"Level index exceeds the rank of the input sparse tensor");
1458 std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1468 cast<RankedTensorType>(getSource().getType()).getRank());
1473 auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1477 Level lvl = lvlIndex.getAPSInt().getZExtValue();
1487 auto getIndexAttr = [
this](int64_t lvlSz) {
1492 if (!ShapedType::isDynamic(lvlShape[lvl]))
1493 return getIndexAttr(lvlShape[lvl]);
1499 SparseTensorEncodingAttr dstEnc,
Value source) {
1503 dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1506 return build(odsBuilder, odsState, dstTp, source);
1515 if (srcLvlTps.size() != dstLvlTps.size())
1516 return emitError(
"Level rank mismatch between source/dest tensors");
1518 for (
auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1519 if (srcLvlTp != dstLvlTp)
1520 return emitError(
"Level type mismatch between source/dest tensors");
1524 return emitError(
"Crd/Pos width mismatch between source/dest tensors");
1528 return emitError(
"Element type mismatch between source/dest tensors");
1532 for (
auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1533 if (srcLvlSz != dstLvlSz) {
1537 return emitError(
"Level size mismatch between source/dest tensors");
1544 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1545 if (getSource().getType() == getDest().getType())
1548 if (
auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1550 if (def.getSource().getType() == getDest().getType())
1551 return def.getSource();
1556 template <
typename ToBufferOp>
1561 typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1563 Type elemTp =
nullptr;
1564 bool withStride =
false;
1565 if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1567 }
else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1568 std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1570 if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1572 }
else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1576 assert(elemTp &&
"unhandled operation.");
1578 bufShape.push_back(ShapedType::kDynamic);
1582 {ShapedType::kDynamic})
1583 : StridedLayoutAttr();
1591 return emitError(
"requested level is out of bounds");
1593 return emitError(
"unexpected type for positions");
1598 ToPositionsOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
1602 return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
1608 return emitError(
"requested level is out of bounds");
1610 return emitError(
"unexpected type for coordinates");
1615 ToCoordinatesOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc,
1619 return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
1625 return emitError(
"expected sparse tensor with a COO region");
1633 return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
1641 return emitError(
"unexpected mismatch in element types");
1646 std::optional<Location> loc,
1651 return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
1656 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1657 return emitError(
"requested dimension out of bound");
1663 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1664 return emitError(
"requested dimension out of bound");
1670 getSpecifier(), getOperation());
1673 template <
typename SpecifierOp>
1675 return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1678 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1679 const StorageSpecifierKind kind = getSpecifierKind();
1680 const auto lvl = getLevel();
1682 if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1683 return op.getValue();
1689 getSpecifier(), getOperation());
1694 const char *regionName,
1697 unsigned expectedNum = inputTypes.size();
1698 if (numArgs != expectedNum)
1699 return op->
emitError() << regionName <<
" region must have exactly "
1700 << expectedNum <<
" arguments";
1702 for (
unsigned i = 0; i < numArgs; i++) {
1704 if (typ != inputTypes[i])
1705 return op->
emitError() << regionName <<
" region argument " << (i + 1)
1706 <<
" type mismatch";
1709 YieldOp yield = dyn_cast<YieldOp>(term);
1712 <<
" region must end with sparse_tensor.yield";
1713 if (!yield.hasSingleResult() ||
1714 yield.getSingleResult().getType() != outputType)
1715 return op->
emitError() << regionName <<
" region yield type mismatch";
1722 Type leftType = getX().getType();
1723 Type rightType = getY().getType();
1724 Type outputType = getOutput().getType();
1725 Region &overlap = getOverlapRegion();
1726 Region &left = getLeftRegion();
1727 Region &right = getRightRegion();
1731 if (!overlap.
empty()) {
1733 TypeRange{leftType, rightType}, outputType)))
1736 if (!left.
empty()) {
1740 }
else if (getLeftIdentity()) {
1741 if (leftType != outputType)
1742 return emitError(
"left=identity requires first argument to have the same "
1743 "type as the output");
1745 if (!right.
empty()) {
1749 }
else if (getRightIdentity()) {
1750 if (rightType != outputType)
1751 return emitError(
"right=identity requires second argument to have the "
1752 "same type as the output");
1758 Type inputType = getX().getType();
1759 Type outputType = getOutput().getType();
1763 Region &present = getPresentRegion();
1764 if (!present.
empty()) {
1769 Region &absent = getAbsentRegion();
1770 if (!absent.
empty()) {
1776 Block *parent = getOperation()->getBlock();
1778 cast<YieldOp>(absentBlock->
getTerminator()).getSingleResult();
1779 if (
auto arg = dyn_cast<BlockArgument>(absentVal)) {
1780 if (arg.getOwner() == parent)
1781 return emitError(
"absent region cannot yield linalg argument");
1783 if (!isa<arith::ConstantOp>(def) &&
1784 (def->getBlock() == absentBlock || def->getBlock() == parent))
1785 return emitError(
"absent region cannot yield locally computed value");
1791 bool ConcatenateOp::needsExtraSort() {
1796 bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](
Value op) {
1803 bool directLowerable =
1804 allSameOrdered && getDimension() == 0 && dstStt.
isIdentity();
1805 return !directLowerable;
1810 const Dimension concatDim = getDimension();
1811 const Dimension dimRank = dstTp.getDimRank();
1813 if (getInputs().size() <= 1)
1814 return emitError(
"Need at least two tensors to concatenate.");
1816 if (concatDim >= dimRank)
1818 "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1819 concatDim, dimRank));
1822 const auto i = it.index();
1824 if (srcTp.hasDynamicDimShape())
1825 return emitError(llvm::formatv(
"Input tensor ${0} has dynamic shape", i));
1826 const Dimension srcDimRank = srcTp.getDimRank();
1827 if (srcDimRank != dimRank)
1829 llvm::formatv(
"Input tensor ${0} has a different rank (rank={1}) "
1830 "from the output tensor (rank={2}).",
1831 i, srcDimRank, dimRank));
1834 for (
Dimension d = 0; d < dimRank; d++) {
1835 const Size dstSh = dstTp.getDimShape()[d];
1836 if (d == concatDim) {
1837 if (!ShapedType::isDynamic(dstSh)) {
1842 for (
const auto src : getInputs())
1848 "The concatenation dimension of the output tensor should be the "
1849 "sum of all the concatenation dimensions of the input tensors.");
1853 for (
const auto src : getInputs()) {
1855 if (!ShapedType::isDynamic(prev) && sh != prev)
1856 return emitError(
"All dimensions (expect for the concatenating one) "
1857 "should be equal.");
1868 build(builder, result, curSize, inBuffer, value,
Value());
1874 if (nValue && nValue.value() < 1)
1875 return emitOpError(
"n must be not less than 1");
1882 if (stt.
getLvlRank() != 1 +
static_cast<Level>(getLvlCoords().size()))
1883 return emitOpError(
"incorrect number of coordinates");
1887 void ForeachOp::build(
1892 build(builder, result, initArgs.
getTypes(), tensor, initArgs, order);
1904 blockArgTypes.append(initArgs.
getTypes().begin(), initArgs.
getTypes().end());
1909 auto ®ion = *result.
regions.front();
1911 builder.
createBlock(®ion, region.end(), blockArgTypes, blockArgLocs);
1912 bodyBuilder(builder, result.
location,
1920 const Dimension dimRank = t.getDimRank();
1921 const auto args = getBody()->getArguments();
1923 if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1924 return emitError(
"Level traverse order does not match tensor's level rank");
1926 if (dimRank + 1 + getInitArgs().size() != args.size())
1927 return emitError(
"Unmatched number of arguments in the block");
1929 if (getNumResults() != getInitArgs().size())
1930 return emitError(
"Mismatch in number of init arguments and results");
1932 if (getResultTypes() != getInitArgs().getTypes())
1933 return emitError(
"Mismatch in types of init arguments and results");
1936 auto yield = cast<YieldOp>(getBody()->getTerminator());
1937 if (yield.getNumOperands() != getNumResults() ||
1938 yield.getOperands().getTypes() != getResultTypes())
1939 return emitError(
"Mismatch in types of yield values and results");
1943 if (args[d].getType() != iTp)
1945 llvm::formatv(
"Expecting Index type for argument at index {0}", d));
1947 const auto elemTp = t.getElementType();
1948 const auto valueTp = args[dimRank].getType();
1949 if (elemTp != valueTp)
1950 emitError(llvm::formatv(
"Unmatched element type between input tensor and "
1951 "block argument, expected:{0}, got: {1}",
1959 return getInputCoo();
1969 emitError(
"Expected COO sparse tensors only");
1972 emitError(
"Unmatched dim2lvl map between input and result COO");
1977 emitError(
"Unmatched storage format between input and result COO");
1983 Type inputType = getX().getType();
1984 Region &formula = getRegion();
1986 TypeRange{inputType, inputType}, inputType);
1991 Type inputType = getX().getType();
1992 Type boolType = b.getI1Type();
1993 Region &formula = getRegion();
2002 emitError(llvm::formatv(
"Expected rank(perm_map) > 1, got {0}", nx));
2005 emitError(llvm::formatv(
"Expected a permutation map, got {0}", xPerm));
2014 const auto checkDim = [&](
Value v,
Size minSize,
const char *message) {
2016 if (!ShapedType::isDynamic(sh) && sh < minSize)
2017 emitError(llvm::formatv(
"{0} got {1} < {2}", message, sh, minSize));
2019 uint64_t n = cn.value();
2021 if (
auto nyAttr = getNyAttr())
2022 ny = nyAttr.getInt();
2023 checkDim(getXy(), n * (nx + ny),
2024 "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
2025 for (
Value opnd : getYs())
2026 checkDim(opnd, n,
"Expected dimension(y) >= n");
2035 IterSpaceType IteratorType::getIterSpaceType()
const {
2040 IteratorType IterSpaceType::getIteratorType()
const {
2060 "expect larger level upper bound than lower bound");
2068 IntegerAttr &lvlHiAttr) {
2085 p << lo <<
" to " << hi;
2091 IntegerAttr lvlHi) {
2092 unsigned lo = lvlLo.getValue().getZExtValue();
2093 unsigned hi = lvlHi.getValue().getZExtValue();
2102 ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2105 adaptor.getHiLvl()));
2110 if (getLoLvl() >= getHiLvl())
2111 return emitOpError(
"expected smaller level low than level high");
2114 if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2116 "parent iterator should be specified iff level lower bound equals 0");
2120 IterSpaceType spaceTp = getResultSpace().getType();
2121 if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2123 "mismatch in parent iterator encoding and iteration space encoding.");
2125 if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2126 return emitOpError(
"parent iterator should be used to extract an "
2127 "iteration space from a consecutive level.");
2138 if (
auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2148 if (isa<SparseTensorEncodingAttr>(attr)) {
2150 return AliasResult::OverridableAlias;
2157 void SparseTensorDialect::initialize() {
2158 addInterface<SparseTensorAsmDialectInterface>();
2160 #define GET_ATTRDEF_LIST
2161 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
2164 #define GET_TYPEDEF_LIST
2165 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
2169 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2171 declarePromisedInterfaces<
2172 bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
2173 NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
2174 ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
2177 #define GET_OP_CLASSES
2178 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2180 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static bool isPermutation(std::vector< PermutationTy > permutation)
static MLIRContext * getContext(OpFoldResult val)
bool isUnique(It begin, It end)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static LogicalResult verifyNumBlockArgs(T *op, Region ®ion, const char *regionName, TypeRange inputTypes, Type outputType)
static ParseResult parseOptionalStaticSlice(int64_t &result, AsmParser &parser)
static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc)
We normalized sparse tensor encoding attribute by always using ordered/unique LT such that "compresse...
static LogicalResult isMatchingWidth(Value mem, unsigned width)
static constexpr bool acceptBitWidth(unsigned bitWidth)
static mlir::ParseResult parseLevelRange(mlir::AsmParser &, mlir::sparse_tensor::Level &, mlir::sparse_tensor::Level &)
Parses a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static LogicalResult lvlIsInBounds(Level lvl, Value tensor)
static constexpr FieldIndex kDataFieldStartingIdx
static constexpr Level kInvalidLevel
static constexpr Level kInvalidFieldIndex
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, mlir::sparse_tensor::Level)
Prints a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind)
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op)
static SmallVector< Size > getSparseFieldShape(const SparseTensorEncodingAttr enc, std::optional< ArrayRef< int64_t >> dimShape)
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, SparseTensorType stt, RankedTensorType valTp, TypeRange lvlTps)
static LogicalResult verifySparsifierGetterSetter(StorageSpecifierKind mdKind, std::optional< Level > lvl, TypedValue< StorageSpecifierType > md, Operation *op)
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, OpaqueProperties prop, RegionRange region, SmallVectorImpl< mlir::Type > &ret)
static bool isAllDense(uint64_t lvlRank, const LevelType *lvlTypes)
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
Base type for affine expression.
void print(raw_ostream &os) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
The possible results of an alias query.
@ NoAlias
The two locations do not alias at all.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseLBrace()=0
Parse a { token.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseQuestion()=0
Parse a '?' token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
OpAsmDialectInterface(Dialect *dialect)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class provides an abstraction over the different types of ranges over Regions.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A wrapper around RankedTensorType, which has three goals:
MLIRContext * getContext() const
Type getElementType() const
unsigned getCrdWidth() const
Returns the coordinate-overhead bitwidth, defaulting to zero.
SmallVector< Size > getBatchLvlShape() const
Returns the batched level-shape.
ArrayRef< LevelType > getLvlTypes() const
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
ArrayRef< Size > getDimShape() const
Returns the dimension-shape.
bool isAllOrdered() const
Returns true for tensors where every level is ordered.
SmallVector< Size > getLvlShape() const
Returns the level-shape.
bool isCOOType(Level startLvl=0, bool isUnique=true) const
Returns true iff this sparse tensor type has a trailing COO region starting at the given level.
Dimension getDimRank() const
Returns the dimension-rank.
bool isAllDense() const
Returns true for tensors where every level is dense.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
bool hasSameDimToLvl(const SparseTensorType &other) const
Returns true iff the two types have the same mapping.
bool hasStaticDimShape() const
Returns true if no dimension has dynamic size.
Level getLvlRank() const
Returns the level-rank.
unsigned getPosWidth() const
Returns the position-overhead bitwidth, defaulting to zero.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
LevelType getLvlType(Level l) const
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
SmallVector< COOSegment > getCOOSegments() const
Returns a list of COO segments in the sparse tensor types.
Provides methods to access fields of a sparse tensor with the given encoding.
unsigned getNumDataFields() const
Gets the total number of data fields (coordinate arrays, position arrays, and a value array) for the ...
unsigned getNumFields() const
Gets the total number of fields for the given sparse tensor encoding.
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with t...
std::pair< FieldIndex, unsigned > getFieldIndexAndStride(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Parses the Sparse Tensor Encoding Attribute (STEA).
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
MPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool isWithCrdLT(LevelType lt)
bool isWithPosLT(LevelType lt)
bool isOrderedLT(LevelType lt)
std::string toMLIRString(LevelType lt)
Dimension toDim(SparseTensorEncodingAttr enc, Level l)
Convenience method to translate the given level to the corresponding dimension.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
bool isSingletonLT(LevelType lt)
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
bool isCompressedLT(LevelType lt)
uint64_t Level
The type of level identifiers and level-ranks.
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
uint64_t getN(LevelType lt)
bool isLooseCompressedLT(LevelType lt)
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
llvm::hash_code hash_value(LevelType lt)
RankedTensorType getRankedTensorType(T &&t)
Convenience method to abbreviate casting getType().
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context)
Given the dimToLvl map, infers the lvlToDim map, or returns empty Affine map when inference fails.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
bool isBlockSparsity(AffineMap dimToLvl)
Given the dimToLvl map, returns if it's block sparsity.
bool isDenseLT(LevelType lt)
uint64_t getM(LevelType lt)
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SparseTensorFieldKind
===-------------------------------------------------------------------—===// The sparse tensor storag...
bool isBatchLT(LevelType lt)
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context)
Returns the lvlToDim map for the given dimToLvl map specific to the block sparse cases.
std::optional< LevelType > buildLevelType(LevelFormat lf, const std::vector< LevelPropNonDefault > &properties, uint64_t n=0, uint64_t m=0)
bool isNOutOfMLT(LevelType lt)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
This class represents an efficient way to signal success or failure.
bool failed() const
Returns true if the provided LogicalResult corresponds to a failure value.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
A simple structure that encodes a range of levels in the sparse tensors that forms a COO segment.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.
LevelType stripStorageIrrelevantProperties() const