#include "llvm/ADT/Bitset.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
  if (dimShape.has_value()) {
    // If the actual tensor shape is known, refine the leading batch levels.
    SmallVector<int64_t> lvlShape =
        enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
    memrefShape.assign(lvlShape.begin(),
                       lvlShape.begin() + enc.getBatchLvlRank());
  }
  // One more dynamic dimension to store the sparse level.
  memrefShape.push_back(ShapedType::kDynamic);
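  // Example (illustrative sketch, not from the original file): for a 2x3
  // block-sparse encoding whose dimToLvl map is
  //   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3),
  // translating the dimension shape {6, 9} with dim2lvl yields the level
  // shape {3, 3, 2, 3}.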
  const auto lvlTypes = enc.getLvlTypes();
  const Level lvlRank = enc.getLvlRank();

  // Per-level storage; the loop variable is advanced inside the body.
  for (Level l = 0; l < lvlRank;) {
    const auto lt = lvlTypes[l];

    if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
      if (!cooSegsRef.front().isSoA) {
        // AoS COO: all singleton levels share one coordinate memref, so the
        // entire segment is skipped at once.
        l = cooSegsRef.front().lvlRange.second;
      }

      // The front COO segment has been handled.
      cooSegsRef = cooSegsRef.drop_front();
      return callback(specType, fieldIdx, fieldKind, lvl, lt);
      return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
      return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
      return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
  llvm_unreachable("unrecognized field kind");
  unsigned numFields = 0;

  unsigned numFields = 0;
std::pair<FieldIndex, unsigned>
StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
                                      std::optional<Level> lvl) const {
    assert(lvl.has_value());
    const Level cooStart = enc.getAoSCOOStart();
    const Level lvlRank = enc.getLvlRank();
    if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
      stride = lvlRank - cooStart;
    }

    if ((lvl && fLvl == lvl.value() && kind == fKind) ||
        (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {

  return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
  return isDynamic(v) ? std::nullopt
                      : std::make_optional(static_cast<uint64_t>(v));
}
std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
  return getStatic(getOffset());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
  return getStatic(getStride());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
  return getStatic(getSize());
}

bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
  return isDynamic(getOffset()) && isDynamic(getStride()) &&
         isDynamic(getSize());
}
std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
  return isDynamic(v) ? "?" : std::to_string(v);
}

void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
  assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
  os << '(';
  os << getStaticString(getOffset());
  os << ", ";
  os << getStaticString(getSize());
  os << ", ";
  os << getStaticString(getStride());
  os << ')';
}
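// Example (illustrative): a fully static slice prints as "(1, 4, 2)", and
// any dynamic component prints as "?", e.g. "(?, 4, ?)".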
  if (parseResult.has_value()) {
    if (parseResult.value().succeeded() && result < 0) {
      parser.emitError(
          parser.getCurrentLocation(),
          "expect positive value or ? for slice offset/size/stride");
      return failure();
    }
    return parseResult.value();
  }

  // Otherwise, '?' is parsed as a dynamic value.
  result = SparseTensorDimSliceAttr::kDynamic;

  int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;

  return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
                                                     offset, size, stride);
LogicalResult SparseTensorDimSliceAttr::verify(
    function_ref<InFlightDiagnostic()> emitError,
    int64_t offset, int64_t size, int64_t stride) {
  if (!isDynamic(offset) && offset < 0)
    return emitError() << "expect non-negative value or ? for slice offset";
  if (!isDynamic(size) && size <= 0)
    return emitError() << "expect positive value or ? for slice size";
  if (!isDynamic(stride) && stride <= 0)
    return emitError() << "expect positive value or ? for slice stride";
  return success();
}
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
  return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
  return withDimToLvl(AffineMap());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                        unsigned crdWidth) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
      crdWidth, getExplicitVal(), getImplicitVal());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
  return withBitWidths(0, 0);
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), explicitVal, getImplicitVal());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
  return withExplicitVal(Attribute());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), implicitVal);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
  return withImplicitVal(Attribute());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  return SparseTensorEncodingAttr::get(
      getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
  return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
}
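// Example (hypothetical usage sketch): each with*/without* helper rebuilds
// the attribute with one component replaced, so they compose naturally:
//   SparseTensorEncodingAttr widened =
//       enc.withBitWidths(/*posWidth=*/64, /*crdWidth=*/64).withoutDimSlices();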
uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
  ArrayRef<LevelType> lvlTypes = getLvlTypes();
  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  return std::distance(lastBatch, lvlTypes.rend());
}
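// Example (illustrative): since verification forces batch levels to be
// leading, the level types [batch, batch, dense, compressed] give a batch
// level-rank of 2 (the reverse search finds the last batch level at index 1).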
bool SparseTensorEncodingAttr::isAllDense() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
}
bool SparseTensorEncodingAttr::isAllOrdered() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
}

Type SparseTensorEncodingAttr::getCrdElemType() const {
  if (getCrdWidth())
    return IntegerType::get(getContext(), getCrdWidth());
  return IndexType::get(getContext());
}

Type SparseTensorEncodingAttr::getPosElemType() const {
  if (getPosWidth())
    return IntegerType::get(getContext(), getPosWidth());
  return IndexType::get(getContext());
}

MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
  return MemRefType::get(shape, getCrdElemType());
}

MemRefType SparseTensorEncodingAttr::getPosMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
  return MemRefType::get(shape, getPosElemType());
}

bool SparseTensorEncodingAttr::isIdentity() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
}

bool SparseTensorEncodingAttr::isPermutation() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
}
Dimension SparseTensorEncodingAttr::getDimRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  const auto dimToLvl = getDimToLvl();
  return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
}

Level SparseTensorEncodingAttr::getLvlRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return getLvlTypes().size();
}

LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
  assert(l < getLvlRank() && "Level is out of bounds");
  return getLvlTypes()[l];
}

bool SparseTensorEncodingAttr::isSlice() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return !getDimSlices().empty();
}

SparseTensorDimSliceAttr
SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
  assert(isSlice() && "Is not a slice");
  const auto dimSlices = getDimSlices();
  assert(dim < dimSlices.size() && "Dimension is out of bounds");
  return dimSlices[dim];
}
std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
  return getDimSlice(dim).getStaticOffset();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
  return getDimSlice(dim).getStaticStride();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
  return getStaticDimSliceOffset(toDim(*this, lvl));
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
  return getStaticDimSliceStride(toDim(*this, lvl));
}
SmallVector<int64_t>
SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
                                         CrdTransDirectionKind dir) const {
  const unsigned rank =
      dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();

  if (isPermutation()) {
    for (unsigned r = 0; r < rank; r++) {
      unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
                                                             : toLvl(*this, r);
      ret.push_back(srcShape[trans]);
    }
  }

  // Handle non-permutation maps.
  AffineMap transMap =
      dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();

  dimRep.reserve(srcShape.size());
  for (int64_t sz : srcShape) {
    if (!ShapedType::isDynamic(sz)) {

    if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
      ret.push_back(c.getValue() + 1);
    }
    if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
        mod && mod.getKind() == AffineExprKind::Mod) {
      // A static bound can still be inferred for "d mod constant", since the
      // result always lies in [0, constant).
      if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
        ret.push_back(bound.getValue());
      }
    }
    ret.push_back(ShapedType::kDynamic);

  assert(ret.size() == rank);
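  // Example (illustrative): with dimToLvl (d0) -> (d0 floordiv 4, d0 mod 4)
  // and a dynamic source shape {?}, the floordiv level stays dynamic but the
  // mod level is inferred as 4, because "d mod 4" always lies in [0, 4).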
ValueRange
SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
                                        ValueRange crds,
                                        CrdTransDirectionKind dir) const {
  SmallVector<Type> retType(
      dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
      builder.getIndexType());
  auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
  return transOp.getOutCrds();
}
  unsigned posWidth = 0;
  unsigned crdWidth = 0;
  SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
                                    "explicitVal", "implicitVal"};

    // Detect admissible keyword.
    auto *it = find(keys, attrName);
    if (it == keys.end()) {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
      return {};
    }
    unsigned keyWordIndex = it - keys.begin();

    // Dispatch on keyword.
    switch (keyWordIndex) {
    case 0: { // map
      auto res = cParser.parseDimLvlMap();
      const auto &dlm = *res;

      const Level lvlRank = dlm.getLvlRank();
      for (Level lvl = 0; lvl < lvlRank; lvl++)
        lvlTypes.push_back(dlm.getLvlType(lvl));

      const Dimension dimRank = dlm.getDimRank();
      for (Dimension dim = 0; dim < dimRank; dim++)
        dimSlices.push_back(dlm.getDimSlice(dim));

      const auto isDefined = [](SparseTensorDimSliceAttr slice) {
        return static_cast<bool>(slice.getImpl());
      };
      if (llvm::any_of(dimSlices, isDefined)) {
        const auto defaultSlice =
            SparseTensorDimSliceAttr::get(parser.getContext());
        for (Dimension dim = 0; dim < dimRank; dim++)
          if (!isDefined(dimSlices[dim]))
            dimSlices[dim] = defaultSlice;
      }

      dimToLvl = dlm.getDimToLvlMap(parser.getContext());
      lvlToDim = dlm.getLvlToDimMap(parser.getContext());
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral position bitwidth");
        return {};
      }
      posWidth = intAttr.getInt();

      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      crdWidth = intAttr.getInt();

      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
        explicitVal = result;
      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
        explicitVal = result;
      } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
        explicitVal = result;
      } else {
        parser.emitError(parser.getNameLoc(),
                         "expected a numeric value for explicitVal");
        return {};
      }

      if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
        implicitVal = result;
      } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
        implicitVal = result;
      } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
        implicitVal = result;
      } else {
        parser.emitError(parser.getNameLoc(),
                         "expected a numeric value for implicitVal");
        return {};
      }

  if (!lvlToDim || lvlToDim.isEmpty()) {
    lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
  }

  return parser.getChecked<SparseTensorEncodingAttr>(
      parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
      explicitVal, implicitVal, dimSlices);
  auto map = static_cast<AffineMap>(getDimToLvl());
  printer << "<{ map = ";
  printSymbols(map, printer);
  printer << '(';
  printDimensions(map, printer, getDimSlices());
  printer << ") -> (";
  printLevels(map, printer, getLvlTypes());
  printer << ')';
  // Print the remaining members only for non-default values.
  if (getPosWidth())
    printer << ", posWidth = " << getPosWidth();
  if (getCrdWidth())
    printer << ", crdWidth = " << getCrdWidth();
  if (getExplicitVal()) {
    printer << ", explicitVal = " << getExplicitVal();
  }
  if (getImplicitVal())
    printer << ", implicitVal = " << getImplicitVal();
  printer << " }>";
void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
                                            AsmPrinter &printer) const {
  for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
    printer << 's' << i << ", ";

void SparseTensorEncodingAttr::printDimensions(
    AffineMap &map, AsmPrinter &printer,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  if (!dimSlices.empty()) {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << " : " << dimSlices[i] << ", ";
    printer << 'd' << map.getNumDims() - 1 << " : "
            << dimSlices[map.getNumDims() - 1];
  } else {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << ", ";
    printer << 'd' << map.getNumDims() - 1;
  }
}

void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
                                           ArrayRef<LevelType> lvlTypes) const {
  for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
  if (!acceptBitWidth(posWidth))
    return emitError() << "unexpected position bitwidth: " << posWidth;
  if (!acceptBitWidth(crdWidth))
    return emitError() << "unexpected coordinate bitwidth: " << crdWidth;

  // Verify every COO segment.
  auto *it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
  while (it != lvlTypes.end()) {
    if (it == lvlTypes.begin() ||
        !(it - 1)->isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>())
      return emitError() << "expected compressed or loose_compressed level "
                            "before singleton level";

    auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT);
    if (!std::all_of(it, curCOOEnd,
                     [](LevelType i) { return isSingletonLT(i); }))
      return emitError() << "expected all singleton lvlTypes "
                            "following a singleton level";
    // Mixed SoA/AoS singleton levels could potentially be supported later.
    if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
          return it->isa<LevelPropNonDefault::SoA>() ==
                 i.isa<LevelPropNonDefault::SoA>();
        }))
      return emitError() << "expected all singleton lvlTypes stored in the "
                            "same memory layout (SoA vs AoS).";
    it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
  }

  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
    return emitError() << "Batch lvlType can only be leading levels.";

  // The SoA property can only be applied to singleton levels.
  auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
    return lt.isa<LevelPropNonDefault::SoA>();
  });
  if (llvm::any_of(soaLvls, [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      }))
    return emitError() << "SoA is only applicable to singleton lvlTypes.";

  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
      it != std::end(lvlTypes)) {
    if (it != lvlTypes.end() - 1)
      return emitError() << "expected n_out_of_m to be the last level type";
    if (!std::all_of(lvlTypes.begin(), it,
                     [](LevelType i) { return isDenseLT(i); }))
      return emitError() << "expected all dense lvlTypes "
                            "before a n_out_of_m level";
    if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
      if (!isBlockSparsity(dimToLvl))
        return emitError()
               << "expected 1xm block structure for n_out_of_m level";
      auto sizes = getBlockSize(dimToLvl);
      unsigned coefficient = 0;
      for (const auto &elem : sizes) {
        if (elem != 0) {
          if (elem != coefficient && coefficient != 0) {
            return emitError() << "expected only one blocked level "
                                  "with the same coefficients";
          }
          coefficient = elem;
        }
      }
      if (coefficient != getM(*it)) {
        return emitError() << "expected coefficients of Affine expressions "
                              "to be equal to m of n_out_of_m level";
      }
    }
  }
  const Level lvlRank = lvlTypes.size();
  if (lvlRank == 0)
    return emitError() << "expected a non-empty array for lvlTypes";

    if (dimToLvl.getNumResults() != lvlRank)
      return emitError()
             << "level-rank mismatch between dimToLvl and lvlTypes: "
             << dimToLvl.getNumResults() << " != " << lvlRank;

      return emitError() << "failed to infer lvlToDim from dimToLvl";
    if (lvlToDim && (inferRes != lvlToDim))
      return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
    if (dimRank > lvlRank)
      return emitError() << "unexpected dimToLvl mapping from " << dimRank
                         << " to " << lvlRank;

  if (!dimSlices.empty()) {
    if (dimSlices.size() != dimRank)
      return emitError()
             << "dimension-rank mismatch between dimSlices and dimToLvl: "
             << dimSlices.size() << " != " << dimRank;
    if (dimRank != lvlRank)
      return emitError()
             << "dimSlices expected dimension-rank to match level-rank: "
             << dimRank << " != " << lvlRank;
  }
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<Size> dimShape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity; this also ensures that the level-rank is
  // coherent across all the fields.
  if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
                    getPosWidth(), getCrdWidth(), getExplicitVal(),
                    getImplicitVal(), getDimSlices())))
    return failure();

  // Check integrity against tensor type specifics: the dimension-rank of the
  // tensor must agree with the dimension-rank of the encoding.
  const Dimension dimRank = dimShape.size();
  if (dimRank == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimRank() != dimRank)
    return emitError()
           << "dimension-rank mismatch between encoding and tensor shape: "
           << getDimRank() << " != " << dimRank;
  if (auto expVal = getExplicitVal()) {
    Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
    if (attrType != elementType) {
      return emitError() << "explicit value type mismatch between encoding and "
                         << "tensor element type: " << attrType
                         << " != " << elementType;
    }
  }
  if (auto impVal = getImplicitVal()) {
    Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
    if (attrType != elementType) {
      return emitError() << "implicit value type mismatch between encoding and "
                         << "tensor element type: " << attrType
                         << " != " << elementType;
    }
    // Currently, only zero is supported as the implicit value.
    auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
    auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
    auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
    if ((impFVal && impFVal.getValue().isNonZero()) ||
        (impIntVal && !impIntVal.getValue().isZero()) ||
        (impComplexVal && (impComplexVal.getImag().isNonZero() ||
                           impComplexVal.getReal().isNonZero()))) {
      return emitError() << "implicit value must be zero";
    }
  }
  return success();
}
Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
  SmallVector<COOSegment> coo = getCOOSegments();
  assert(coo.size() == 1 || coo.empty());
  if (!coo.empty() && coo.front().isAoS()) {
    return coo.front().lvlRange.first;
  }
  // No AoS COO tail was found; return the level-rank as a sentinel.
  return getLvlRank();
}

SmallVector<COOSegment>
mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
  SmallVector<COOSegment> ret;
  if (getLvlRank() <= 1)
    return ret;

  while (l < getLvlRank()) {
      auto cur = lts.begin() + l;
      auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      });
      unsigned cooLen = std::distance(cur, end);
        ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
                                 lts[l + 1].isa<LevelPropNonDefault::SoA>()});
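  // Example (illustrative): for the level types
  //   [dense, compressed(nonunique), singleton]  (an AoS COO tail),
  // getCOOSegments() reports one segment spanning levels [1, 3), and
  // getAoSCOOStart() returns 1; without such a tail it returns the level-rank.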
  if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
    return false;
  for (Level l = startLvl + 1; l < lvlRank; ++l)
    if (!isSingletonLvl(l))
      return false;
  // When isUnique is set, the last level (the only compressed level if
  // lvlRank == 1, otherwise the last singleton) must be unique.
  return !isUnique || isUniqueLvl(lvlRank - 1);

  lvlTypes.reserve(lvlRank);
    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,

  return SparseTensorEncodingAttr::get(
      getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
      getCrdWidth(), getExplicitVal(), getImplicitVal());
SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
  if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
    return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
  if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
    return mdtp.getEncoding();
  return nullptr;
}
  auto map = static_cast<AffineMap>(dimToLvl);

  lvlExprs.reserve(numLvls);
  // Map each dimension to the (floordiv, mod) level components applied to it,
  // so that the inverse expressions can be built afterwards.
  std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
  for (unsigned i = 0, n = numLvls; i < n; i++) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
      assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
             "expected only one floordiv for each dimension");
      // Multiplier of the corresponding level variable.
      components.push_back(binOp.getRHS());
      // The map key is the position of the dimension.
      lvlExprComponents[pos] = components;

      auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
      assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
             "expected floordiv before mod");

      assert(false && "expected floordiv or mod");

  // Build lvlExprs from the gathered components.
  for (auto &components : lvlExprComponents) {
    assert(components.second.size() == 3 &&
           "expected 3 components to build lvlExprs");
    lvlExprs.push_back(addOp);
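  // Example (illustrative): for dimToLvl (d0) -> (d0 floordiv 2, d0 mod 2),
  // the components gathered for d0 are [l0, 2, l1], so the inferred
  // lvlToDim map is (l0, l1) -> (l0 * 2 + l1).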
1114 "expected dimToLvl to be block sparsity for calling getBlockSize");
1117 if (
auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
1119 blockSize.push_back(
1120 dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
1123 blockSize.push_back(0);
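  // Example (illustrative): for the 2x3 BSR map
  //   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3),
  // getBlockSize returns {2, 3}: one entry per "mod" level.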
  std::map<unsigned, int64_t> coefficientMap;
  bool hasBlock = false;
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      // Require the "dim op const" form with a positive constant.
      auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
      auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
      if (!dimOp || !conOp || conOp.getValue() <= 0)
        return false;

      auto pos = dimOp.getPosition();
        // Expect only one floordiv for each dimension.
        if (coefficientMap.find(pos) != coefficientMap.end())
          return false;
        // Record the coefficient of the floordiv.
        coefficientMap[pos] = conOp.getValue();
        // Expect a floordiv before the mod.
        if (coefficientMap.find(pos) == coefficientMap.end())
          return false;
        // Expect the mod to carry the same coefficient as the floordiv.
        if (conOp.getValue() != coefficientMap[pos])
          return false;
    } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
      auto pos = dimOp.getPosition();
      // Expect the dimension to be unset.
      if (coefficientMap.find(pos) != coefficientMap.end())
        return false;
      coefficientMap[pos] = 0;
  auto hasNonIdentityMap = [](Value v) {
  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
         llvm::any_of(op->getResults(), hasNonIdentityMap);
  assert(enc.isPermutation() && "Non permutation map not supported");
  if (const auto dimToLvl = enc.getDimToLvl())

  assert(enc.isPermutation() && "Non permutation map not supported");
  if (const auto lvlToDim = enc.getLvlToDim())

static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
  for (auto lt : enc.getLvlTypes())

  return SparseTensorEncodingAttr::get(
      enc.getContext(), lts,
      AffineMap(), // dimToLvl (irrelevant to storage specifier)
      AffineMap(), // lvlToDim (irrelevant to storage specifier)
      0, 0,        // Always use `index` for memSize and lvlSize.
      Attribute(), // explicitVal (irrelevant to storage specifier)
      Attribute(), // implicitVal (irrelevant to storage specifier)
      enc.getDimSlices());
StorageSpecifierType
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
  return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
}

StorageSpecifierType
StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                                 MLIRContext *ctx,
                                 SparseTensorEncodingAttr encoding) {
  return Base::getChecked(emitError, ctx,
                          getNormalizedEncodingForSpecifier(encoding));
}

static LogicalResult verifySparsifierGetterSetter(
    StorageSpecifierKind mdKind, std::optional<Level> lvl,
    TypedValue<StorageSpecifierType> md, Operation *op) {
  if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
    return op->emitError(
        "redundant level argument for querying value memory size");
  }

  const auto enc = md.getType().getEncoding();
  const Level lvlRank = enc.getLvlRank();

  if (mdKind == StorageSpecifierKind::DimOffset ||
      mdKind == StorageSpecifierKind::DimStride)
    if (!enc.isSlice())
      return op->emitError("requested slice data on non-slice tensor");

  if (mdKind != StorageSpecifierKind::ValMemSize) {
    if (!lvl)
      return op->emitError("missing level argument");

    const Level l = lvl.value();
    if (l >= lvlRank)
      return op->emitError("requested level is out of bounds");

    if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
      return op->emitError(
          "requested position memory size on a singleton level");
  }

  llvm_unreachable("Unrecognizable FieldKind");
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                      SparseTensorType stt,
                                      RankedTensorType valTp,
                                      TypeRange lvlTps) {
  if (requiresStaticShape && !stt.hasStaticDimShape())
    return op->emitError("the sparse-tensor must have static shape");
  if (!stt.hasEncoding())
    return op->emitError("the sparse-tensor must have an encoding attribute");

  // Verify the trailing COO segment: only a trailing COO is supported, and it
  // must be the last input.
  auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
  // The coordinates should have the shape <? x rank>.
  unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
  if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
    op->emitError("input/output trailing COO level-ranks don't match");
  }

  // Verify that all types match.
  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
    return op->emitError("inconsistent number of fields between input/output");

  bool misMatch = false;

    Type inputTp = nullptr;
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      inputTp = valTp;
    } else {
      assert(fid == idx && stt.getLvlType(lvl) == lt);
      inputTp = lvlTps[idx++];
    }
    // The input element type and the expected element type should match.
    Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
    if (inpElemTp != expElemTp) {
      misMatch = true;

  if (misMatch)
    return op->emitError("input/output element-types don't match");
  const auto lvlsTp = getLevels().getTypes();

LogicalResult DisassembleOp::verify() {
  if (getOutValues().getType() != getRetValues().getType())
    return emitError("output values and return value type mismatch");

  for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
    if (ot.getType() != rt.getType())
      return emitError("output levels and return levels type mismatch");

  const auto lvlsTp = getRetLevels().getTypes();
LogicalResult ConvertOp::verify() {
  if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
    if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
      if (tp1.getRank() != tp2.getRank())
        return emitError("unexpected conversion mismatch in rank");
      auto dstEnc =
          llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
      if (dstEnc && dstEnc.isSlice())
        return emitError("cannot convert to a sparse tensor slice");

      auto shape1 = tp1.getShape();
      auto shape2 = tp2.getShape();
      // Accept size matches between the source and the destination type
      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches.
      for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
          return emitError("unexpected conversion mismatch in dimension ") << d;
      return success();
    }
  }
  return emitError("unexpected type in convert");
}
bool ConvertOp::needsExtraSort() {

  // A dense source defined by a sparse constant can be converted directly.
  if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
    if (isa<SparseElementsAttr>(constOp.getValue()))
      return false;

LogicalResult CrdTranslateOp::verify() {
  uint64_t inRank = getEncoder().getLvlRank();
  uint64_t outRank = getEncoder().getDimRank();

  if (getDirection() == CrdTransDirectionKind::dim2lvl)
    std::swap(inRank, outRank);

  if (inRank != getInCrds().size() || outRank != getOutCrds().size())
    return emitError("Coordinate rank mismatch with encoding");
LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                   SmallVectorImpl<OpFoldResult> &results) {
  if (getEncoder().isIdentity()) {
    results.assign(getInCrds().begin(), getInCrds().end());
    return success();
  }

  AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
                       ? getEncoder().getDimToLvl()
                       : getEncoder().getLvlToDim();
    results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);

  // Fuse dim2lvl/lvl2dim pairs.
  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
  bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
    return v.getDefiningOp() == def;
  });

  bool oppositeDir = def.getDirection() != getDirection();
  bool sameOracle =
      def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
  bool sameCount = def.getNumResults() == getInCrds().size();
  if (!oppositeDir || !sameOracle || !sameCount)
    return failure();

  // The definition produces the coordinates in the same order as the inputs.
  bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
                                [](auto valuePair) {
                                  auto [lhs, rhs] = valuePair;
                                  return lhs == rhs;
                                });

  // l1 = dim2lvl(lvl2dim(l0)) = l0.
  results.append(def.getInCrds().begin(), def.getInCrds().end());
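  // Example (illustrative): a lvl2dim translation immediately followed by a
  // dim2lvl translation with the same encoding is a round trip, so the fold
  // replaces the second op's results with the first op's inputs.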
  Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
  return build(builder, state, source, val);

  if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
    if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
      emitError("Level index exceeds the rank of the input sparse tensor");

std::optional<uint64_t> LvlOp::getConstantLvlIndex() {

      cast<RankedTensorType>(getSource().getType()).getRank());

  auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());

  Level lvl = lvlIndex.getAPSInt().getZExtValue();

  auto getIndexAttr = [this](int64_t lvlSz) {

  if (!ShapedType::isDynamic(lvlShape[lvl]))
    return getIndexAttr(lvlShape[lvl]);
void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                             SparseTensorEncodingAttr dstEnc, Value source) {
  SmallVector<int64_t> dstDimShape =
      dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
  return build(odsBuilder, odsState, dstTp, source);
}

  if (srcLvlTps.size() != dstLvlTps.size())
    return emitError("Level rank mismatch between source/dest tensors");

  for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
    if (srcLvlTp != dstLvlTp)
      return emitError("Level type mismatch between source/dest tensors");

    return emitError("Crd/Pos width mismatch between source/dest tensors");

    return emitError("Element type mismatch between source/dest tensors");

  for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
    if (srcLvlSz != dstLvlSz) {
      // For simplicity, all level sizes must match exactly (no dynamic
      // wildcards on either side).
      return emitError("Level size mismatch between source/dest tensors");
    }
  }
OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
  if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
    // Reinterpreting A -> B followed by B -> A cancels out.
    if (def.getSource().getType() == getDest().getType())
      return def.getSource();
  }
template <typename ToBufferOp>
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
                                           OpaqueProperties prop,
                                           RegionRange region,
                                           SmallVectorImpl<mlir::Type> &ret) {
  typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
  Type elemTp = nullptr;
  bool withStride = false;
  if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
    elemTp = stt.getPosType();
  } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
                       std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
    elemTp = stt.getCrdType();
    if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
      withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
  } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
    elemTp = stt.getElementType();
  }

  assert(elemTp && "unhandled operation.");
  bufShape.push_back(ShapedType::kDynamic);

  auto layout = withStride ? StridedLayoutAttr::get(stt.getContext(),
                                                    ShapedType::kDynamic,
                                                    {ShapedType::kDynamic})
                           : StridedLayoutAttr();

    return emitError("requested level is out of bounds");
    return emitError("unexpected type for positions");

LogicalResult
ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                ValueRange ops, DictionaryAttr attr,
                                OpaqueProperties prop, RegionRange region,
                                SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
}

    return emitError("requested level is out of bounds");
    return emitError("unexpected type for coordinates");

LogicalResult
ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                  ValueRange ops, DictionaryAttr attr,
                                  OpaqueProperties prop, RegionRange region,
                                  SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
}

    return emitError("expected sparse tensor with a COO region");

LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
    SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
                                                      ret);
}

    return emitError("unexpected mismatch in element types");

LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
                                           std::optional<Location> loc,
                                           ValueRange ops, DictionaryAttr attr,
                                           OpaqueProperties prop,
                                           RegionRange region,
                                           SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
}
  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
    return emitError("requested dimension out of bound");

  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
    return emitError("requested dimension out of bound");

  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
                                      getSpecifier(), getOperation());

template <typename SpecifierOp>
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
  return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
}

OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
  const StorageSpecifierKind kind = getSpecifierKind();
  const auto lvl = getLevel();
  if (SetStorageSpecifierOp op = getSpecifierSetDef(*this))
    if (kind == op.getSpecifierKind() && lvl == op.getLevel())
      return op.getValue();
  return {};
}

  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
                                      getSpecifier(), getOperation());
static LogicalResult verifyNumBlockArgs(T *op, Region &region,
                                        const char *regionName,
                                        TypeRange inputTypes, Type outputType) {
  unsigned expectedNum = inputTypes.size();
  if (numArgs != expectedNum)
    return op->emitError() << regionName << " region must have exactly "
                           << expectedNum << " arguments";

  for (unsigned i = 0; i < numArgs; i++) {
    if (typ != inputTypes[i])
      return op->emitError() << regionName << " region argument " << (i + 1)
                             << " type mismatch";
  }
  YieldOp yield = dyn_cast<YieldOp>(term);
  if (!yield)
    return op->emitError() << regionName
                           << " region must end with sparse_tensor.yield";
  if (!yield.hasSingleResult() ||
      yield.getSingleResult().getType() != outputType)
    return op->emitError() << regionName << " region yield type mismatch";
  Type leftType = getX().getType();
  Type rightType = getY().getType();
  Type outputType = getOutput().getType();
  Region &overlap = getOverlapRegion();
  Region &left = getLeftRegion();
  Region &right = getRightRegion();

  // Check the correct number of block arguments and return type for each
  // non-empty region.
  if (!overlap.empty()) {
    if (failed(verifyNumBlockArgs(this, overlap, "overlap",
                                  TypeRange{leftType, rightType}, outputType)))
      return failure();
  }
  if (!left.empty()) {
    if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
                                  outputType)))
      return failure();
  } else if (getLeftIdentity()) {
    if (leftType != outputType)
      return emitError("left=identity requires first argument to have the same "
                       "type as the output");
  }
  if (!right.empty()) {
    if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
                                  outputType)))
      return failure();
  } else if (getRightIdentity()) {
    if (rightType != outputType)
      return emitError("right=identity requires second argument to have the "
                       "same type as the output");
  }
  Type inputType = getX().getType();
  Type outputType = getOutput().getType();

  // Check the correct number of block arguments and return type for each
  // non-empty region.
  Region &present = getPresentRegion();
  if (!present.empty()) {
    if (failed(verifyNumBlockArgs(this, present, "present",
                                  TypeRange{inputType}, outputType)))
      return failure();
  }
  Region &absent = getAbsentRegion();
  if (!absent.empty()) {
    // The absent branch can only yield invariant values.
    Block *parent = getOperation()->getBlock();
    Value absentVal =
        cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
    if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
      if (arg.getOwner() == parent)
        return emitError("absent region cannot yield linalg argument");
    } else if (Operation *def = absentVal.getDefiningOp()) {
      if (!isa<arith::ConstantOp>(def) &&
          (def->getBlock() == absentBlock || def->getBlock() == parent))
        return emitError("absent region cannot yield locally computed value");
    }
  }
bool ConcatenateOp::needsExtraSort() {

  bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
    return getSparseTensorType(op).hasSameDimToLvl(dstStt);
  });

  bool directLowerable =
      allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
  return !directLowerable;
}

LogicalResult ConcatenateOp::verify() {
  const Dimension concatDim = getDimension();
  const Dimension dimRank = dstTp.getDimRank();

  if (getInputs().size() <= 1)
    return emitError("Need at least two tensors to concatenate.");

  if (concatDim >= dimRank)
    return emitError(llvm::formatv(
        "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
        concatDim, dimRank));

    const auto i = it.index();
    if (srcTp.hasDynamicDimShape())
      return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
    const Dimension srcDimRank = srcTp.getDimRank();
    if (srcDimRank != dimRank)
      return emitError(
          llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
                        "from the output tensor (rank={2}).",
                        i, srcDimRank, dimRank));

  for (Dimension d = 0; d < dimRank; d++) {
    const Size dstSh = dstTp.getDimShape()[d];
    if (d == concatDim) {
      if (!ShapedType::isDynamic(dstSh)) {
        Size sumSz = 0;
        for (const auto src : getInputs())
          sumSz += getSparseTensorType(src).getDimShape()[d];
        // If all dimensions are statically known, the sum of all the input
        // concatenation dimensions must equal the output dimension.
        if (sumSz != dstSh)
          return emitError(
              "The concatenation dimension of the output tensor should be the "
              "sum of all the concatenation dimensions of the input tensors.");
      }
    } else {
      Size prev = dstSh;
      for (const auto src : getInputs()) {
        const auto sh = getSparseTensorType(src).getDimShape()[d];
        if (!ShapedType::isDynamic(prev) && sh != prev)
          return emitError("All dimensions (except for the concatenating one) "
                           "should be equal.");
        prev = sh;
      }
    }
  }
  build(builder, result, curSize, inBuffer, value, Value());

  if (nValue && nValue.value() < 1)
    return emitOpError("n must be not less than 1");

  if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
    return emitOpError("incorrect number of coordinates");
void ForeachOp::build(
  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);

  // Reduction variables trail the dimension coordinates and the value.
  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());

  OpBuilder::InsertionGuard guard(builder);
  auto &region = *result.regions.front();
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuilder(builder, result.location,
              bodyBlock->getArguments().slice(0, dimRank),
              bodyBlock->getArguments()[dimRank],
              bodyBlock->getArguments().drop_front(dimRank + 1));
}

LogicalResult ForeachOp::verify() {
  const auto t = getSparseTensorType(getTensor());
  const Dimension dimRank = t.getDimRank();
  const auto args = getBody()->getArguments();

  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
    return emitError("Level traverse order does not match tensor's level rank");

  if (dimRank + 1 + getInitArgs().size() != args.size())
    return emitError("Unmatched number of arguments in the block");

  if (getNumResults() != getInitArgs().size())
    return emitError("Mismatch in number of init arguments and results");

  if (getResultTypes() != getInitArgs().getTypes())
    return emitError("Mismatch in types of init arguments and results");

  auto yield = cast<YieldOp>(getBody()->getTerminator());
  if (yield.getNumOperands() != getNumResults() ||
      yield.getOperands().getTypes() != getResultTypes())
    return emitError("Mismatch in types of yield values and results");

      emitError(
          llvm::formatv("Expecting Index type for argument at index {0}", d));

  const auto elemTp = t.getElementType();
  const auto valueTp = args[dimRank].getType();
  if (elemTp != valueTp)
    emitError(llvm::formatv("Unmatched element type between input tensor and "
                            "block argument, expected:{0}, got: {1}",
                            elemTp, valueTp));
  return getInputCoo();

  if (!srcStt.isCOOType() || !dstStt.isCOOType())
    emitError("Expected COO sparse tensors only");

  if (!srcStt.hasSameDimToLvl(dstStt))
    emitError("Unmatched dim2lvl map between input and result COO");

    emitError("Unmatched storage format between input and result COO");
LogicalResult ReduceOp::verify() {
  Type inputType = getX().getType();
  Region &formula = getRegion();
  return verifyNumBlockArgs(this, formula, "reduce",
                            TypeRange{inputType, inputType}, inputType);
}

LogicalResult SelectOp::verify() {
  Type inputType = getX().getType();
  Type boolType = b.getI1Type();
  Region &formula = getRegion();
  return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
                            boolType);
}

  if (nx < 1)
    emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
  if (!xPerm.isPermutation())
    emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));

  // Verify dimensions.
  const auto checkDim = [&](Value v, Size minSize, const char *message) {
    const Size sh = getMemRefType(v).getShape()[0];
    if (!ShapedType::isDynamic(sh) && sh < minSize)
      emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
  };
  uint64_t n = cn.value();
  uint64_t ny = 0;
  if (auto nyAttr = getNyAttr())
    ny = nyAttr.getInt();
  checkDim(getXy(), n * (nx + ny),
           "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
  for (Value opnd : getYs())
    checkDim(opnd, n, "Expected dimension(y) >= n");
IterSpaceType IteratorType::getIterSpaceType() const {
  return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
                            getHiLvl());
}

IteratorType IterSpaceType::getIteratorType() const {
  return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
}

  if (lvlHi <= lvlLo)
    return parser.emitError(parser.getNameLoc(),
                            "expect larger level upper bound than lower bound");

static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
                                   IntegerAttr &lvlHiAttr) {

    p << lo << " to " << hi;

static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
                            IntegerAttr lvlHi) {
  unsigned lo = lvlLo.getValue().getZExtValue();
  unsigned hi = lvlHi.getValue().getZExtValue();
  printLevelRange(p, lo, hi);
}
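// Example (illustrative): the range [2, 3) prints simply as "2", while
// [0, 2) prints as "0 to 2".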
  ParseResult crdList =
      parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult {
        if (parser.parseOptionalKeyword("_")) {
          if (parser.parseArgument(definedArgs.emplace_back()))
            return failure();
          definedSet.set(cnt);
        }
        cnt += 1;
        return success();
      });

    return parser.emitError(parser.getNameLoc(),
                            "parsed more value than expected.");

  if (failed(crdList)) {
    return parser.emitError(
        parser.getNameLoc(),
        "expecting SSA value or \"_\" for level coordinates");
  }
  assert(definedArgs.size() == definedSet.count());

  if (definedSet.empty())
    return;

  for (unsigned i = 0; i < size; i++) {
    if (definedSet[i]) {
      p << blocksArgs.front();
      blocksArgs = blocksArgs.drop_front();
    } else {
      p << "_";
    }
  }
  assert(blocksArgs.empty());
  for (auto &coord : coords)
    coord.type = parser.getBuilder().getIndexType();
  state.addAttribute("crdUsedLvls",
                     parser.getBuilder().getI64IntegerAttr(definedSet));
  if (iterators.size() != spaces.size())
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch in number of sparse iterators and sparse spaces");

  size_t numCrds = coords.size();

  blockArgs.append(coords);

  if (iterSpaceTps.size() != spaces.size())
    return parser.emitError(parser.getNameLoc(),
                            "mismatch in number of iteration space operands "
                            "and iteration space types");

  for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
    IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
    if (!spaceTp)
      return parser.emitError(parser.getNameLoc(),
                              "expected sparse_tensor.iter_space type for "
                              "iteration space operands");
    it.type = spaceTp.getIteratorType();
  }

  if (args.size() != initArgs.size() || args.size() != state.types.size()) {
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch in number of iteration arguments and return values");
  }

  for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {

  size_t numCrds = coords.size();

  blockArgs.append(coords);

  if (iterSpaceTps.size() != spaces.size())
    return parser.emitError(parser.getNameLoc(),
                            "mismatch in number of iteration space operands "
                            "and iteration space types");

  state.operands.append(spacesVals);

  if (args.size() != initArgs.size() || args.size() != state.types.size()) {
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch in number of iteration arguments and return values");
  }

  for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
LogicalResult ExtractIterSpaceOp::inferReturnTypes(
  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
  ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
                                   adaptor.getHiLvl()));
  return success();
}

LogicalResult ExtractIterSpaceOp::verify() {
  if (getLoLvl() >= getHiLvl())
    return emitOpError("expected smaller level low than level high");

  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
    return emitOpError(
        "parent iterator should be specified iff level lower bound equals 0");
  }

  if (pIter) {
    IterSpaceType spaceTp = getExtractedSpace().getType();
    if (pIter.getType().getEncoding() != spaceTp.getEncoding())
      return emitOpError(
          "mismatch in parent iterator encoding and iteration space encoding.");

    if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
      return emitOpError("parent iterator should be used to extract an "
                         "iteration space from a consecutive level.");
  }
  auto itTp = getIterator().getType();

  if (stt.getEncoding() != itTp.getEncoding())
    return emitOpError("mismatch in tensor encoding and iterator encoding.");

  if (stt.getLvlRank() != itTp.getHiLvl())
    return emitOpError("must use last-level iterator to extract values.");
    llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
    for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
      if (auto crd = iterateOp.getLvlCrd(i)) {
        if (crd->getUsers().empty())
          toRemove.set(crd->getArgNumber());

    // All coordinates are used.
    if (toRemove.none())
      return failure();

    iterateOp.setCrdUsedLvls(newUsedLvls);
    iterateOp.getBody()->eraseArguments(toRemove);
  unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
  // All ones.
  I64BitSet set((1 << rank) - 1);
  return build(builder, odsState, iterSpace, initArgs, set);

  // Starts with the list of user-provided loop arguments.
  for (Value v : initArgs)
    bodyBlock->addArgument(v.getType(), v.getLoc());

  // Followed by the list of used coordinates.
  for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
    bodyBlock->addArgument(builder.getIndexType(), odsState.location);

  // Ends with the sparse iterator.
  bodyBlock->addArgument(
      llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
      odsState.location);

  if (iters.size() != 1)
    return parser.emitError(parser.getNameLoc(),
                            "expected only one iterator/iteration space");

  iterArgs.append(iters);
                                    StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
  if (initializers.empty())
    return;

  p << prefix << '(';
  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
    p << std::get<0>(it) << " = " << std::get<1>(it);
  });
  p << ")";
}
template <typename SparseLoopOp>
static LogicalResult verifySparseLoopOp(SparseLoopOp op) {
  if (op.getInitArgs().size() != op.getNumResults()) {
    return op.emitOpError(
        "mismatch in number of loop-carried values and defined values");
  }
  if (op.getCrdUsedLvls().max() > op.getSpaceDim())
    return op.emitOpError("required out-of-bound coordinates");
  p << " " << getIterator() << " in " << getIterSpace();
  if (!getCrdUsedLvls().empty()) {

  p << " : " << getIterSpace().getType() << " ";
  if (!getInitArgs().empty())
    p.printArrowTypeList(getInitArgs().getTypes());

  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/!getInitArgs().empty());
LogicalResult IterateOp::verifyRegions() {
  if (getIterator().getType() != getIterSpace().getType().getIteratorType())
    return emitOpError("mismatch in iterator and iteration space type");
  if (getNumRegionIterArgs() != getNumResults())
    return emitOpError(
        "mismatch in number of basic block args and defined values");

  auto initArgs = getInitArgs();
  auto iterArgs = getRegionIterArgs();
  auto yieldVals = getYieldedValues();
  auto opResults = getResults();
  if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
                        opResults.size()})) {
    return emitOpError() << "number mismatch between iter args and results.";
  }

  for (auto [i, init, iter, yield, ret] :
       llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
    if (init.getType() != ret.getType())
      return emitOpError() << "types mismatch between " << i
                           << "th iter operand and defined value";
    if (iter.getType() != ret.getType())
      return emitOpError() << "types mismatch between " << i
                           << "th iter region arg and defined value";
    if (yield.getType() != ret.getType())
      return emitOpError() << "types mismatch between " << i
                           << "th yield value and defined value";
  }
MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
  return getInitArgsMutable();
}

Block::BlockArgListType IterateOp::getRegionIterArgs() {
  return getRegion().getArguments().take_front(getNumRegionIterArgs());
}

std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
  return cast<sparse_tensor::YieldOp>(
             getRegion().getBlocks().front().getTerminator())
      .getResultsMutable();
}

std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }

OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  return getInitArgs();
}

  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
                        unsigned numCases) {
  unsigned rank =
      cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim();

  return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
                            initArgs, set, cases,
                            /*caseRegionsCount=*/numCases);

  result.addAttribute("operandSegmentSizes",
                      parser.getBuilder().getDenseI32ArrayAttr(
                          {static_cast<int32_t>(spaces.size()),
                           static_cast<int32_t>(result.types.size())}));

  auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType());
  definedIts[i].type = spaceTp.getIteratorType();

  definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
  llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; });

  if (!getCrdUsedLvls().empty()) {

  p << " : (" << getIterSpaces().getTypes() << ")";
  if (!getInitArgs().empty())
    p.printArrowTypeList(getInitArgs().getTypes());

  for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
                             getRegionDefinedSpace(idx));
    p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/!getInitArgs().empty());
  }
ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
  return cast<sparse_tensor::YieldOp>(
             getRegion(regionIdx).getBlocks().front().getTerminator())
      .getResults();
}

LogicalResult CoIterateOp::verifyRegions() {
  for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
    if (getNumRegionIterArgs() != getNumResults())
      return emitOpError(
          "mismatch in number of basic block args and defined values");

    auto initArgs = getInitArgs();
    auto iterArgs = getRegionIterArgs(r);
    auto yieldVals = getYieldedValues(r);
    auto opResults = getResults();
    if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
                          opResults.size()})) {
      return emitOpError()
             << "number mismatch between iter args and results on " << r
             << "th region";
    }

    for (auto [i, init, iter, yield, ret] :
         llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
      if (init.getType() != ret.getType())
        return emitOpError()
               << "types mismatch between " << i
               << "th iter operand and defined value on " << r << "th region";
      if (iter.getType() != ret.getType())
        return emitOpError() << "types mismatch between " << i
                             << "th iter region arg and defined value on " << r
                             << "th region";
      if (yield.getType() != ret.getType())
        return emitOpError()
               << "types mismatch between " << i
               << "th yield value and defined value on " << r << "th region";
    }
  }

  auto cases = getRegionDefinedSpaces();
  llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
  if (set.size() != getNumRegions())
    return emitOpError("contains duplicated cases.");
  I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
  for (Region &r : getCaseRegions())
    if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
      ret.push_back(&r);
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;

    if (isa<SparseTensorEncodingAttr>(attr)) {
      os << "sparse";
      return AliasResult::OverridableAlias;
    }
    return AliasResult::NoAlias;
void SparseTensorDialect::initialize() {
  addInterface<SparseTensorAsmDialectInterface>();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
  declarePromisedInterfaces<
      bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
      NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
      ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"

#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static bool isPermutation(std::vector< PermutationTy > permutation)
static MLIRContext * getContext(OpFoldResult val)
bool isUnique(It begin, It end)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static LogicalResult verifyNumBlockArgs(T *op, Region ®ion, const char *regionName, TypeRange inputTypes, Type outputType)
static ParseResult parseOptionalStaticSlice(int64_t &result, AsmParser &parser)
static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc)
We normalized sparse tensor encoding attribute by always using ordered/unique LT such that "compresse...
static ParseResult parseUsedCoordList(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &coords)
static LogicalResult isMatchingWidth(Value mem, unsigned width)
static constexpr bool acceptBitWidth(unsigned bitWidth)
static mlir::ParseResult parseLevelRange(mlir::AsmParser &, mlir::sparse_tensor::Level &, mlir::sparse_tensor::Level &)
Parses a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static LogicalResult lvlIsInBounds(Level lvl, Value tensor)
static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size, Block::BlockArgListType blocksArgs, I64BitSet definedSet)
static constexpr FieldIndex kDataFieldStartingIdx
static constexpr Level kInvalidLevel
static LogicalResult verifySparseLoopOp(SparseLoopOp op)
static constexpr Level kInvalidFieldIndex
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, mlir::sparse_tensor::Level)
Prints a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind)
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op)
static SmallVector< Size > getSparseFieldShape(const SparseTensorEncodingAttr enc, std::optional< ArrayRef< int64_t >> dimShape)
static ParseResult parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &iterators, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static ParseResult parseOptionalDefinedList(OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, SmallVectorImpl< OpAsmParser::Argument > &definedArgs, unsigned maxCnt=std::numeric_limits< unsigned >::max(), OpAsmParser::Delimiter delimiter=OpAsmParser::Delimiter::Paren)
Parses a list of optional defined list in the form of "(%val0, _, %val1, ...)", where _ is used to an...
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, SparseTensorType stt, RankedTensorType valTp, TypeRange lvlTps)
static ParseResult parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< Value > &spacesVals, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static LogicalResult verifySparsifierGetterSetter(StorageSpecifierKind mdKind, std::optional< Level > lvl, TypedValue< StorageSpecifierType > md, Operation *op)
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, OpaqueProperties prop, RegionRange region, SmallVectorImpl< mlir::Type > &ret)
static bool isAllDense(uint64_t lvlRank, const LevelType *lvlTypes)
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
Base type for affine expression.
void print(raw_ostream &os) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
The possible results of an alias query.
@ NoAlias
The two locations do not alias at all.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseLBrace()=0
Parse a { token.
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseQuestion()=0
Parse a '?' token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printArrowTypeList(TypeRange &&types)
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
OpAsmDialectInterface(Dialect *dialect)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
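A short verifier-style sketch using the diagnostic helpers above; the single-result requirement is an invented example, not a rule from this dialect.
#include "mlir/IR/Operation.h"

// Hypothetical invariant check: require exactly one result.
static mlir::LogicalResult verifySingleResult(mlir::Operation *op) {
  if (op->getNumResults() != 1)
    return op->emitOpError("expected exactly one result, got ")
           << op->getNumResults();
  return mlir::success();
}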
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen.
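A sketch of the in-place modification protocol those two hooks define; the attribute being rewritten is hypothetical.
#include "mlir/IR/PatternMatch.h"

// Bracket an in-place mutation so the pattern driver can track it.
static void retagInPlace(mlir::PatternRewriter &rewriter, mlir::Operation *op,
                         mlir::Attribute newTag) {
  rewriter.startOpModification(op);
  op->setAttr("tag", newTag); // hypothetical attribute name
  rewriter.finalizeOpModification(op);
}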
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
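An illustrative helper built from the Value accessors above (not from the source).
#include "mlir/IR/Operation.h"

// Returns true iff `v` is an op result produced by an op named `opName`;
// block arguments have no defining op, so getDefiningOp() is null for them.
static bool isProducedBy(mlir::Value v, llvm::StringRef opName) {
  if (mlir::Operation *def = v.getDefiningOp())
    return def->getName().getStringRef() == opName;
  return false;
}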
A simple wrapper to encode a bitset of (at most 64) levels, currently used by the sparse_tensor.iterate operation.
iterator_range< const_set_bits_iterator > bits() const
I64BitSet & set(unsigned i)
A wrapper around RankedTensorType, which has three goals:
MLIRContext * getContext() const
Type getElementType() const
unsigned getCrdWidth() const
Returns the coordinate-overhead bitwidth, defaulting to zero.
SmallVector< Size > getBatchLvlShape() const
Returns the batched level-shape.
ArrayRef< LevelType > getLvlTypes() const
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
ArrayRef< Size > getDimShape() const
Returns the dimension-shape.
bool isAllOrdered() const
Returns true for tensors where every level is ordered.
SmallVector< Size > getLvlShape() const
Returns the level-shape.
bool isCOOType(Level startLvl=0, bool isUnique=true) const
Returns true iff this sparse tensor type has a trailing COO region starting at the given level.
Dimension getDimRank() const
Returns the dimension-rank.
bool isAllDense() const
Returns true for tensors where every level is dense.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
bool hasSameDimToLvl(const SparseTensorType &other) const
Returns true iff the two types have the same mapping.
bool hasStaticDimShape() const
Returns true if no dimension has dynamic size.
Level getLvlRank() const
Returns the level-rank.
unsigned getPosWidth() const
Returns the position-overhead bitwidth, defaulting to zero.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least two levels.
LevelType getLvlType(Level l) const
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
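A hedged sketch of typical SparseTensorType queries; it assumes `val` carries a RankedTensorType, and the compressed-level count is an invented use.
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Count compressed levels of a (possibly sparse) ranked tensor value.
static unsigned countCompressedLevels(Value val) {
  SparseTensorType stt = getSparseTensorType(val);
  if (!stt.hasEncoding())
    return 0; // unannotated tensors are all-dense
  unsigned num = 0;
  for (Level l = 0, e = stt.getLvlRank(); l < e; ++l)
    if (isCompressedLT(stt.getLvlType(l)))
      ++num;
  return num;
}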
Provides methods to access fields of a sparse tensor with the given encoding.
unsigned getNumDataFields() const
Gets the total number of data fields (coordinate arrays, position arrays, and a value array) for the given sparse tensor encoding.
unsigned getNumFields() const
Gets the total number of fields for the given sparse tensor encoding.
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with the corresponding field index, field kind, level, and level type.
std::pair< FieldIndex, unsigned > getFieldIndexAndStride(SparseTensorFieldKind kind, std::optional< Level > lvl) const
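A sketch of field traversal with StorageLayout; counting position memrefs is an invented use, and the encoding-taking constructor is assumed from the listing above.
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Count the position memrefs an encoding will allocate; returning true from
// the callback continues the traversal.
static unsigned countPosFields(SparseTensorEncodingAttr enc) {
  StorageLayout layout(enc);
  unsigned numPos = 0;
  layout.foreachField([&](FieldIndex, SparseTensorFieldKind kind, Level,
                          LevelType) {
    if (kind == SparseTensorFieldKind::PosMemRef)
      ++numPos;
    return true;
  });
  return numPos;
}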
Parses the Sparse Tensor Encoding Attribute (STEA).
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool isWithCrdLT(LevelType lt)
bool isWithPosLT(LevelType lt)
bool isOrderedLT(LevelType lt)
std::string toMLIRString(LevelType lt)
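These predicates classify a LevelType's storage needs; below is a small illustrative printer, a sketch only.
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir::sparse_tensor;

// Print which overhead arrays a level's storage requires.
static void describeLevel(LevelType lt) {
  llvm::errs() << toMLIRString(lt)
               << ": positions=" << (isWithPosLT(lt) ? "yes" : "no")
               << " coordinates=" << (isWithCrdLT(lt) ? "yes" : "no")
               << " ordered=" << (isOrderedLT(lt) ? "yes" : "no") << "\n";
}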
Dimension toDim(SparseTensorEncodingAttr enc, Level l)
Convenience method to translate the given level to the corresponding dimension.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
bool isSingletonLT(LevelType lt)
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
uint64_t Level
The type of level identifiers and level-ranks.
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
uint64_t getN(LevelType lt)
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic (for shapes).
llvm::hash_code hash_value(LevelType lt)
RankedTensorType getRankedTensorType(T &&t)
Convenience method to abbreviate casting getType().
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context)
Given the dimToLvl map, infers the lvlToDim map, or returns an empty AffineMap when inference fails.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
bool isBlockSparsity(AffineMap dimToLvl)
Given the dimToLvl map, returns whether it encodes block sparsity.
bool isDenseLT(LevelType lt)
uint64_t getM(LevelType lt)
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff the MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SparseTensorFieldKind
The sparse tensor storage scheme for a tensor is organized as a single compound type with the following fields...
bool isBatchLT(LevelType lt)
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context)
Returns the lvlToDim map for the given dimToLvl map specific to the block sparse cases.
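A hedged sketch tying the block-sparsity helpers together on the canonical 2x2 BSR dimToLvl map; the expected results are noted as comments, not guarantees.
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Build (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
// and query the block-sparsity helpers on it.
static void blockSparsityExample(MLIRContext *ctx) {
  AffineExpr d0 = getAffineDimExpr(0, ctx);
  AffineExpr d1 = getAffineDimExpr(1, ctx);
  AffineMap dimToLvl = AffineMap::get(
      2, 0, {d0.floorDiv(2), d1.floorDiv(2), d0 % 2, d1 % 2}, ctx);
  if (isBlockSparsity(dimToLvl)) {
    SmallVector<unsigned> sizes = getBlockSize(dimToLvl); // expect {2, 2}
    AffineMap lvlToDim = inverseBlockSparsity(dimToLvl, ctx);
    (void)sizes;
    (void)lvlToDim;
  }
}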
std::optional< LevelType > buildLevelType(LevelFormat lf, const std::vector< LevelPropNonDefault > &properties, uint64_t n=0, uint64_t m=0)
bool isNOutOfMLT(LevelType lt)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
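A tiny sketch of the expression builders plus simplifyAffineExpr; the folded result is expected, not asserted.
#include "mlir/IR/AffineExpr.h"

using namespace mlir;

// Build (d0 + d1) + (d0 - d1); flattening is expected to fold it to d0 * 2.
static AffineExpr foldExample(MLIRContext *ctx) {
  AffineExpr d0 = getAffineDimExpr(0, ctx);
  AffineExpr d1 = getAffineDimExpr(1, ctx);
  AffineExpr e = getAffineBinaryOpExpr(AffineExprKind::Add, d0 + d1, d0 - d1);
  return simplifyAffineExpr(e, /*numDims=*/2, /*numSymbols=*/0);
}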
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
LogicalResult matchAndRewrite(IterateOp iterateOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
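A skeletal OpRewritePattern rooted on IterateOp, matching the override shown above; the body is a placeholder that never fires.
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/PatternMatch.h"

struct SimplifyIterate
    : public mlir::OpRewritePattern<mlir::sparse_tensor::IterateOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::sparse_tensor::IterateOp op,
                  mlir::PatternRewriter &rewriter) const override {
    return mlir::failure(); // no rewrite in this sketch
  }
};

// Registration would use RewritePatternSet::add from above:
//   patterns.add<SimplifyIterate>(ctx);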
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Region * addRegion()
Create a region that should be attached to the operation.
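A hedged sketch of generic op construction through OperationState; "arith.addi" is just a familiar op name for illustration, and the arith dialect is assumed to be loaded in the context.
#include "mlir/IR/Builders.h"

// Assemble operands, result type, and name, then create and insert the op.
static mlir::Operation *buildAdd(mlir::OpBuilder &builder, mlir::Location loc,
                                 mlir::Value lhs, mlir::Value rhs) {
  mlir::OperationState state(loc, "arith.addi");
  state.addOperands({lhs, rhs});
  state.addTypes(lhs.getType());
  return builder.create(state);
}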
A simple structure that encodes a range of levels in the sparse tensors that forms a COO segment.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.
LevelType stripStorageIrrelevantProperties() const
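A small sketch combining the last two entries: strip storage-irrelevant properties, then test the underlying format; the isa<LevelFormat::Compressed>() usage is assumed from the Enums header.
#include "mlir/Dialect/SparseTensor/IR/Enums.h"

using namespace mlir::sparse_tensor;

// Test the format after dropping properties (such as ordering) that do not
// change the storage layout.
static bool isCompressedCore(LevelType lt) {
  return lt.stripStorageIrrelevantProperties().isa<LevelFormat::Compressed>();
}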