26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/FormatVariadic.h"
29#define GET_ATTRDEF_CLASSES
30#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
31#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
41#define GET_TYPEDEF_CLASSES
42#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
51 return llvm::hash_value(
static_cast<uint64_t
>(lt));
79 if (dimShape.has_value()) {
83 enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
84 memrefShape.assign(lvlShape.begin(),
85 lvlShape.begin() + enc.getBatchLvlRank());
88 memrefShape.push_back(ShapedType::kDynamic);
104 const auto lvlTypes = enc.getLvlTypes();
105 const Level lvlRank = enc.getLvlRank();
111 for (
Level l = 0; l < lvlRank; ) {
112 const auto lt = lvlTypes[l];
121 if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
122 if (!cooSegsRef.front().isSoA) {
125 l = cooSegsRef.front().lvlRange.second;
131 cooSegsRef = cooSegsRef.drop_front();
159 const Type posMemType = MemRefType::get(memrefShape, stt.
getPosType());
161 const Type crdMemType = MemRefType::get(memrefShape, stt.
getCrdType());
171 return callback(specType, fieldIdx, fieldKind, lvl, lt);
173 return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
175 return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
177 return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
179 llvm_unreachable(
"unrecognized field kind");
184 unsigned numFields = 0;
194 unsigned numFields = 0;
206std::pair<FieldIndex, unsigned>
208 std::optional<Level> lvl)
const {
212 assert(lvl.has_value());
213 const Level cooStart = enc.getAoSCOOStart();
214 const Level lvlRank = enc.getLvlRank();
215 if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
217 stride = lvlRank - cooStart;
223 if ((lvl && fLvl == lvl.value() && kind == fKind) ||
232 return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
239std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(
int64_t v) {
240 return isDynamic(v) ? std::nullopt
241 : std::make_optional(
static_cast<uint64_t
>(v));
244std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset()
const {
245 return getStatic(getOffset());
248std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride()
const {
249 return getStatic(getStride());
252std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize()
const {
253 return getStatic(getSize());
256bool SparseTensorDimSliceAttr::isCompletelyDynamic()
const {
257 return isDynamic(getOffset()) && isDynamic(getStride()) &&
258 isDynamic(getSize());
261std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
262 return isDynamic(v) ?
"?" : std::to_string(v);
265void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os)
const {
266 assert(getImpl() &&
"Uninitialized SparseTensorDimSliceAttr");
268 os << getStaticString(getOffset());
270 os << getStaticString(getSize());
272 os << getStaticString(getStride());
276void SparseTensorDimSliceAttr::print(AsmPrinter &printer)
const {
283 if (parseResult.has_value()) {
284 if (parseResult.value().succeeded() &&
result < 0) {
287 "expect positive value or ? for slice offset/size/stride");
290 return parseResult.value();
294 result = SparseTensorDimSliceAttr::kDynamic;
298Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
299 int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
311 offset, size, stride);
316 int64_t offset, int64_t size, int64_t stride) {
317 if (!isDynamic(offset) && offset < 0)
318 return emitError() <<
"expect non-negative value or ? for slice offset";
319 if (!isDynamic(size) && size <= 0)
320 return emitError() <<
"expect positive value or ? for slice size";
321 if (!isDynamic(stride) && stride <= 0)
322 return emitError() <<
"expect positive value or ? for slice stride";
326SparseTensorEncodingAttr
327SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl)
const {
328 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
329 return SparseTensorEncodingAttr::get(
330 getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
331 getCrdWidth(), getExplicitVal(), getImplicitVal());
334SparseTensorEncodingAttr
335SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc)
const {
336 return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
339SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl()
const {
340 return withDimToLvl(AffineMap());
343SparseTensorEncodingAttr
344SparseTensorEncodingAttr::withBitWidths(
unsigned posWidth,
345 unsigned crdWidth)
const {
346 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
347 return SparseTensorEncodingAttr::get(
348 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
349 crdWidth, getExplicitVal(), getImplicitVal());
352SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths()
const {
353 return withBitWidths(0, 0);
356SparseTensorEncodingAttr
357SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal)
const {
358 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
359 return SparseTensorEncodingAttr::get(
360 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
361 getCrdWidth(), explicitVal, getImplicitVal());
364SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal()
const {
365 return withExplicitVal(Attribute());
368SparseTensorEncodingAttr
369SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal)
const {
370 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
371 return SparseTensorEncodingAttr::get(
372 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
373 getCrdWidth(), getExplicitVal(), implicitVal);
376SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal()
const {
377 return withImplicitVal(Attribute());
380SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
381 ArrayRef<SparseTensorDimSliceAttr> dimSlices)
const {
382 return SparseTensorEncodingAttr::get(
383 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
384 getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
387SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices()
const {
388 return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
391uint64_t SparseTensorEncodingAttr::getBatchLvlRank()
const {
392 ArrayRef<LevelType> lvlTypes = getLvlTypes();
393 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(),
isBatchLT);
394 return std::distance(lastBatch, lvlTypes.rend());
397bool SparseTensorEncodingAttr::isAllDense()
const {
398 return !getImpl() || llvm::all_of(getLvlTypes(),
isDenseLT);
401bool SparseTensorEncodingAttr::isAllOrdered()
const {
402 return !getImpl() || llvm::all_of(getLvlTypes(),
isOrderedLT);
405Type SparseTensorEncodingAttr::getCrdElemType()
const {
409 return IntegerType::get(
getContext(), getCrdWidth());
413Type SparseTensorEncodingAttr::getPosElemType()
const {
417 return IntegerType::get(
getContext(), getPosWidth());
421MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
422 std::optional<ArrayRef<int64_t>> dimShape)
const {
424 return MemRefType::get(shape, getCrdElemType());
427MemRefType SparseTensorEncodingAttr::getPosMemRefType(
428 std::optional<ArrayRef<int64_t>> dimShape)
const {
430 return MemRefType::get(shape, getPosElemType());
433bool SparseTensorEncodingAttr::isIdentity()
const {
434 return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
437bool SparseTensorEncodingAttr::isPermutation()
const {
438 return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
441Dimension SparseTensorEncodingAttr::getDimRank()
const {
442 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
443 const auto dimToLvl = getDimToLvl();
444 return dimToLvl ? dimToLvl.
getNumDims() : getLvlRank();
447Level SparseTensorEncodingAttr::getLvlRank()
const {
448 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
449 return getLvlTypes().size();
455 assert(l < getLvlRank() &&
"Level is out of bounds");
456 return getLvlTypes()[l];
459bool SparseTensorEncodingAttr::isSlice()
const {
460 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
461 return !getDimSlices().empty();
464SparseTensorDimSliceAttr
465SparseTensorEncodingAttr::getDimSlice(
Dimension dim)
const {
466 assert(isSlice() &&
"Is not a slice");
467 const auto dimSlices = getDimSlices();
468 assert(dim < dimSlices.size() &&
"Dimension is out of bounds");
469 return dimSlices[dim];
472std::optional<uint64_t>
473SparseTensorEncodingAttr::getStaticDimSliceOffset(
Dimension dim)
const {
474 return getDimSlice(dim).getStaticOffset();
477std::optional<uint64_t>
478SparseTensorEncodingAttr::getStaticDimSliceStride(
Dimension dim)
const {
479 return getDimSlice(dim).getStaticStride();
482std::optional<uint64_t>
483SparseTensorEncodingAttr::getStaticLvlSliceOffset(
Level lvl)
const {
484 return getStaticDimSliceOffset(
toDim(*
this, lvl));
487std::optional<uint64_t>
488SparseTensorEncodingAttr::getStaticLvlSliceStride(
Level lvl)
const {
489 return getStaticDimSliceStride(
toDim(*
this, lvl));
493SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
494 CrdTransDirectionKind dir)
const {
496 return SmallVector<int64_t>(srcShape);
498 SmallVector<int64_t> ret;
500 dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
504 for (
unsigned r = 0; r < rank; r++) {
505 unsigned trans = dir == CrdTransDirectionKind::dim2lvl ?
toDim(*
this, r)
507 ret.push_back(srcShape[trans]);
514 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
516 SmallVector<AffineExpr> dimRep;
517 dimRep.reserve(srcShape.size());
518 for (int64_t sz : srcShape) {
519 if (ShapedType::isStatic(sz)) {
528 for (AffineExpr exp : transMap.
getResults()) {
531 simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
533 if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
534 ret.push_back(c.getValue() + 1);
536 if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
537 mod && mod.getKind() == AffineExprKind::Mod) {
540 if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
541 ret.push_back(bound.getValue());
545 ret.push_back(ShapedType::kDynamic);
548 assert(ret.size() == rank);
553SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
555 CrdTransDirectionKind dir)
const {
559 SmallVector<Type> retType(
560 dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
563 CrdTranslateOp::create(builder, loc, retType, crds, dir, *
this);
564 return transOp.getOutCrds();
567Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
575 SmallVector<LevelType> lvlTypes;
576 SmallVector<SparseTensorDimSliceAttr> dimSlices;
577 AffineMap dimToLvl = {};
578 AffineMap lvlToDim = {};
579 unsigned posWidth = 0;
580 unsigned crdWidth = 0;
581 Attribute explicitVal;
582 Attribute implicitVal;
584 SmallVector<StringRef, 5> keys = {
"map",
"posWidth",
"crdWidth",
585 "explicitVal",
"implicitVal"};
588 auto *it = find(keys, attrName);
589 if (it == keys.end()) {
593 unsigned keyWordIndex = it - keys.begin();
598 switch (keyWordIndex) {
601 auto res = cParser.parseDimLvlMap();
604 const auto &dlm = *res;
606 const Level lvlRank = dlm.getLvlRank();
607 for (
Level lvl = 0; lvl < lvlRank; lvl++)
608 lvlTypes.push_back(dlm.getLvlType(lvl));
610 const Dimension dimRank = dlm.getDimRank();
611 for (
Dimension dim = 0; dim < dimRank; dim++)
612 dimSlices.push_back(dlm.getDimSlice(dim));
616 const auto isDefined = [](SparseTensorDimSliceAttr slice) {
617 return static_cast<bool>(slice.getImpl());
619 if (llvm::any_of(dimSlices, isDefined)) {
620 const auto defaultSlice =
621 SparseTensorDimSliceAttr::get(parser.
getContext());
622 for (
Dimension dim = 0; dim < dimRank; dim++)
623 if (!isDefined(dimSlices[dim]))
624 dimSlices[dim] = defaultSlice;
629 dimToLvl = dlm.getDimToLvlMap(parser.
getContext());
630 lvlToDim = dlm.getLvlToDimMap(parser.
getContext());
637 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
640 "expected an integral position bitwidth");
643 posWidth = intAttr.getInt();
650 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
653 "expected an integral index bitwidth");
656 crdWidth = intAttr.getInt();
663 if (
auto result = llvm::dyn_cast<FloatAttr>(attr)) {
665 }
else if (
auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
667 }
else if (
auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
671 "expected a numeric value for explicitVal");
680 if (
auto result = llvm::dyn_cast<FloatAttr>(attr)) {
682 }
else if (
auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
684 }
else if (
auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
688 "expected a numeric value for implicitVal");
706 if (!lvlToDim || lvlToDim.
isEmpty()) {
709 return parser.
getChecked<SparseTensorEncodingAttr>(
710 parser.
getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
711 explicitVal, implicitVal, dimSlices);
714void SparseTensorEncodingAttr::print(AsmPrinter &printer)
const {
715 auto map =
static_cast<AffineMap
>(getDimToLvl());
719 printer <<
"<{ map = ";
720 printSymbols(map, printer);
722 printDimensions(map, printer, getDimSlices());
724 printLevels(map, printer, getLvlTypes());
728 printer <<
", posWidth = " << getPosWidth();
730 printer <<
", crdWidth = " << getCrdWidth();
731 if (getExplicitVal()) {
732 printer <<
", explicitVal = " << getExplicitVal();
734 if (getImplicitVal())
735 printer <<
", implicitVal = " << getImplicitVal();
739void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
740 AsmPrinter &printer)
const {
744 for (
unsigned i = 0, n = map.
getNumSymbols() - 1; i < n; i++)
745 printer <<
's' << i <<
", ";
751void SparseTensorEncodingAttr::printDimensions(
752 AffineMap &map, AsmPrinter &printer,
753 ArrayRef<SparseTensorDimSliceAttr> dimSlices)
const {
754 if (!dimSlices.empty()) {
755 for (
unsigned i = 0, n = map.
getNumDims() - 1; i < n; i++)
756 printer <<
'd' << i <<
" : " << dimSlices[i] <<
", ";
758 printer <<
'd' << map.
getNumDims() - 1 <<
" : "
762 for (
unsigned i = 0, n = map.
getNumDims() - 1; i < n; i++)
763 printer <<
'd' << i <<
", ";
769void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
770 ArrayRef<LevelType> lvlTypes)
const {
771 for (
unsigned i = 0, n = map.
getNumResults() - 1; i < n; i++) {
782LogicalResult SparseTensorEncodingAttr::verify(
784 AffineMap dimToLvl, AffineMap lvlToDim,
unsigned posWidth,
785 unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
786 ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
788 return emitError() <<
"unexpected position bitwidth: " << posWidth;
790 return emitError() <<
"unexpected coordinate bitwidth: " << crdWidth;
794 while (it != lvlTypes.end()) {
795 if (it == lvlTypes.begin() ||
797 return emitError() <<
"expected compressed or loose_compressed level "
798 "before singleton level";
800 auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(),
isSingletonLT);
802 return emitError() <<
"expected all singleton lvlTypes "
803 "following a singleton level";
805 if (!std::all_of(it, curCOOEnd, [it](
LevelType i) {
809 return emitError() <<
"expected all singleton lvlTypes stored in the "
810 "same memory layout (SoA vs AoS).";
815 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(),
isBatchLT);
816 if (!std::all_of(lastBatch, lvlTypes.rend(),
isBatchLT))
817 return emitError() <<
"Batch lvlType can only be leading levels.";
820 auto soaLvls = llvm::make_filter_range(lvlTypes, [](
LevelType lt) {
823 if (llvm::any_of(soaLvls, [](
LevelType lt) {
826 return emitError() <<
"SoA is only applicable to singleton lvlTypes.";
833 for (
auto [i, lt] : llvm::drop_begin(llvm::enumerate(lvlTypes))) {
835 return emitError() <<
"dense level cannot follow a non-unique level";
839 if (
auto it = llvm::find_if(lvlTypes,
isNOutOfMLT);
840 it != std::end(lvlTypes)) {
841 if (it != lvlTypes.end() - 1)
842 return emitError() <<
"expected n_out_of_m to be the last level type";
843 if (!std::all_of(lvlTypes.begin(), it,
isDenseLT))
844 return emitError() <<
"expected all dense lvlTypes "
845 "before a n_out_of_m level";
849 <<
"expected 1xm block structure for n_out_of_m level";
852 unsigned coefficient = 0;
853 for (
const auto &elem : sizes) {
855 if (elem != coefficient && coefficient != 0) {
856 return emitError() <<
"expected only one blocked level "
857 "with the same coefficients";
862 if (coefficient !=
getM(*it)) {
863 return emitError() <<
"expected coeffiencts of Affine expressions "
864 "to be equal to m of n_out_of_m level";
873 const Level lvlRank = lvlTypes.size();
875 return emitError() <<
"expected a non-empty array for lvlTypes";
881 <<
"level-rank mismatch between dimToLvl and lvlTypes: "
886 return emitError() <<
"failed to infer lvlToDim from dimToLvl";
887 if (lvlToDim && (inferRes != lvlToDim))
888 return emitError() <<
"expected lvlToDim to be an inverse of dimToLvl";
889 if (dimRank > lvlRank)
890 return emitError() <<
"unexpected dimToLvl mapping from " << dimRank
891 <<
" to " << lvlRank;
893 if (!dimSlices.empty()) {
894 if (dimSlices.size() != dimRank)
896 <<
"dimension-rank mismatch between dimSlices and dimToLvl: "
897 << dimSlices.size() <<
" != " << dimRank;
900 if (dimRank != lvlRank)
902 <<
"dimSlices expected dimension-rank to match level-rank: "
903 << dimRank <<
" != " << lvlRank;
908LogicalResult SparseTensorEncodingAttr::verifyEncoding(
909 ArrayRef<Size> dimShape, Type elementType,
914 getPosWidth(), getCrdWidth(), getExplicitVal(),
915 getImplicitVal(), getDimSlices())))
920 const Dimension dimRank = dimShape.size();
922 return emitError() <<
"expected non-scalar sparse tensor";
923 if (getDimRank() != dimRank)
925 <<
"dimension-rank mismatch between encoding and tensor shape: "
926 << getDimRank() <<
" != " << dimRank;
927 if (
auto expVal = getExplicitVal()) {
928 Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
929 if (attrType != elementType) {
930 return emitError() <<
"explicit value type mismatch between encoding and "
931 <<
"tensor element type: " << attrType
932 <<
" != " << elementType;
935 if (
auto impVal = getImplicitVal()) {
936 Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
937 if (attrType != elementType) {
938 return emitError() <<
"implicit value type mismatch between encoding and "
939 <<
"tensor element type: " << attrType
940 <<
" != " << elementType;
943 auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
944 auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
945 auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
946 if ((impFVal && impFVal.getValue().isNonZero()) ||
947 (impIntVal && !impIntVal.getValue().isZero()) ||
948 (impComplexVal && (impComplexVal.getImag().isNonZero() ||
949 impComplexVal.getReal().isNonZero()))) {
950 return emitError() <<
"implicit value must be zero";
956Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart()
const {
957 SmallVector<COOSegment> coo = getCOOSegments();
958 assert(coo.size() == 1 || coo.empty());
959 if (!coo.empty() && coo.front().isAoS()) {
960 return coo.front().lvlRange.first;
965SmallVector<COOSegment>
966mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments()
const {
967 SmallVector<COOSegment> ret;
968 if (getLvlRank() <= 1)
971 ArrayRef<LevelType> lts = getLvlTypes();
973 while (l < getLvlRank()) {
976 auto cur = lts.begin() + l;
977 auto end = std::find_if(cur + 1, lts.end(), [](
LevelType lt) {
978 return !lt.isa<LevelFormat::Singleton>();
980 unsigned cooLen = std::distance(cur, end);
986 ret.push_back(
COOSegment{std::make_pair(l, l + cooLen),
1007 for (
Level l = startLvl + 1; l < lvlRank; ++l)
1019 lvlTypes.reserve(lvlRank);
1026 std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
1031 auto enc = SparseTensorEncodingAttr::get(
1041SparseTensorEncodingAttr
1043 if (
auto ttp = llvm::dyn_cast<RankedTensorType>(type))
1044 return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
1045 if (
auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
1046 return mdtp.getEncoding();
1052 auto map =
static_cast<AffineMap>(dimToLvl);
1069 lvlExprs.reserve(numLvls);
1072 std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
1073 for (
unsigned i = 0, n = numLvls; i < n; i++) {
1075 if (
auto binOp = dyn_cast<AffineBinaryOpExpr>(
result)) {
1078 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1079 assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
1080 "expected only one floordiv for each dimension");
1085 components.push_back(binOp.getRHS());
1087 lvlExprComponents[pos] = components;
1089 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1090 assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
1091 "expected floordiv before mod");
1096 assert(
false &&
"expected floordiv or mod");
1106 for (
auto &components : lvlExprComponents) {
1107 assert(components.second.size() == 3 &&
1108 "expected 3 components to build lvlExprs");
1113 lvlExprs.push_back(addOp);
1120 "expected dimToLvl to be block sparsity for calling getBlockSize");
1123 if (
auto binOp = dyn_cast<AffineBinaryOpExpr>(
result)) {
1125 blockSize.push_back(
1126 dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
1129 blockSize.push_back(0);
1138 std::map<unsigned, int64_t> coeffientMap;
1139 bool hasBlock =
false;
1141 if (
auto binOp = dyn_cast<AffineBinaryOpExpr>(
result)) {
1143 auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
1144 auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
1145 if (!dimOp || !conOp || conOp.getValue() <= 0)
1148 auto pos = dimOp.getPosition();
1151 auto [it,
inserted] = coeffientMap.try_emplace(pos);
1155 it->second = conOp.getValue();
1158 auto it = coeffientMap.find(pos);
1159 if (it == coeffientMap.end())
1162 if (conOp.getValue() != it->second)
1168 }
else if (
auto dimOp = dyn_cast<AffineDimExpr>(
result)) {
1169 auto pos = dimOp.getPosition();
1171 if (!coeffientMap.try_emplace(pos, 0).second)
1181 auto hasNonIdentityMap = [](
Value v) {
1186 return llvm::any_of(op->
getOperands(), hasNonIdentityMap) ||
1187 llvm::any_of(op->
getResults(), hasNonIdentityMap);
1192 assert(enc.isPermutation() &&
"Non permutation map not supported");
1193 if (
const auto dimToLvl = enc.getDimToLvl())
1201 assert(enc.isPermutation() &&
"Non permutation map not supported");
1202 if (
const auto lvlToDim = enc.getLvlToDim())
1212static SparseTensorEncodingAttr
1215 for (
auto lt : enc.getLvlTypes())
1218 return SparseTensorEncodingAttr::get(
1219 enc.getContext(), lts,
1229 enc.getDimSlices());
1233StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
1240 SparseTensorEncodingAttr encoding) {
1259 StorageSpecifierKind mdKind, std::optional<Level> lvl,
1261 if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
1263 "redundant level argument for querying value memory size");
1266 const auto enc = md.getType().getEncoding();
1267 const Level lvlRank = enc.getLvlRank();
1269 if (mdKind == StorageSpecifierKind::DimOffset ||
1270 mdKind == StorageSpecifierKind::DimStride)
1272 return op->
emitError(
"requested slice data on non-slice tensor");
1274 if (mdKind != StorageSpecifierKind::ValMemSize) {
1276 return op->
emitError(
"missing level argument");
1278 const Level l = lvl.value();
1280 return op->
emitError(
"requested level is out of bounds");
1282 if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
1284 "requested position memory size on a singleton level");
1300 llvm_unreachable(
"Unrecognizable FieldKind");
1305 RankedTensorType valTp,
1308 return op->
emitError(
"the sparse-tensor must have static shape");
1310 return op->
emitError(
"the sparse-tensor must have an encoding attribute");
1316 auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
1318 unsigned expCOORank = stt.
getLvlRank() - cooStartLvl;
1319 if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
1320 return op->
emitError(
"input/output trailing COO level-ranks don't match");
1327 return op->
emitError(
"inconsistent number of fields between input/output");
1330 bool misMatch =
false;
1337 Type inputTp =
nullptr;
1341 assert(fid == idx && stt.
getLvlType(lvl) == lt);
1342 inputTp = lvlTps[idx++];
1345 Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
1347 if (inpElemTp != expElemTp) {
1355 return op->
emitError(
"input/output element-types don't match");
1359LogicalResult AssembleOp::verify() {
1360 RankedTensorType valuesTp = getValues().getType();
1361 const auto lvlsTp = getLevels().getTypes();
1366LogicalResult DisassembleOp::verify() {
1368 return emitError(
"output values and return value type mismatch");
1370 for (
auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1371 if (ot.getType() != rt.getType())
1372 return emitError(
"output levels and return levels type mismatch");
1374 RankedTensorType valuesTp = getRetValues().getType();
1375 const auto lvlsTp = getRetLevels().getTypes();
1380LogicalResult ConvertOp::verify() {
1381 RankedTensorType tp1 = getSource().getType();
1382 RankedTensorType tp2 = getDest().getType();
1383 if (tp1.getRank() != tp2.getRank())
1384 return emitError(
"unexpected conversion mismatch in rank");
1386 llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
1387 if (dstEnc && dstEnc.isSlice())
1388 return emitError(
"cannot convert to a sparse tensor slice");
1390 auto shape1 = tp1.getShape();
1391 auto shape2 = tp2.getShape();
1395 for (
Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1396 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1397 return emitError(
"unexpected conversion mismatch in dimension ") << d;
1401OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
1407bool ConvertOp::needsExtraSort() {
1426 if (
auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1427 if (isa<SparseElementsAttr>(constOp.getValue()))
1433LogicalResult CrdTranslateOp::verify() {
1434 uint64_t inRank = getEncoder().getLvlRank();
1435 uint64_t outRank = getEncoder().getDimRank();
1437 if (getDirection() == CrdTransDirectionKind::dim2lvl)
1438 std::swap(inRank, outRank);
1440 if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1441 return emitError(
"Coordinate rank mismatch with encoding");
1446LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
1447 SmallVectorImpl<OpFoldResult> &results) {
1448 if (getEncoder().isIdentity()) {
1449 results.assign(getInCrds().begin(), getInCrds().end());
1453 AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
1454 ? getEncoder().getDimToLvl()
1455 : getEncoder().getLvlToDim();
1457 results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
1462 auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
1463 bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
1469 bool oppositeDir = def.getDirection() != getDirection();
1471 def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
1472 bool sameCount = def.getNumResults() == getInCrds().size();
1473 if (!oppositeDir || !sameOracle || !sameCount)
1478 bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
1479 [](
auto valuePair) {
1480 auto [
lhs,
rhs] = valuePair;
1488 results.append(def.getInCrds().begin(), def.getInCrds().end());
1492void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1495 return build(builder, state, source, val);
1498LogicalResult LvlOp::verify() {
1499 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1501 if (
static_cast<uint64_t
>(lvl.value()) >= stt.
getLvlRank())
1503 "Level index exceeds the rank of the input sparse tensor");
1508std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1518 cast<RankedTensorType>(getSource().
getType()).getRank());
1522OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1523 auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1527 Level lvl = lvlIndex.getAPSInt().getZExtValue();
1537 auto getIndexAttr = [
this](int64_t lvlSz) {
1538 return IntegerAttr::get(IndexType::get(
getContext()), APInt(64, lvlSz));
1542 if (ShapedType::isStatic(lvlShape[lvl]))
1543 return getIndexAttr(lvlShape[lvl]);
1548void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1549 SparseTensorEncodingAttr dstEnc, Value source) {
1551 SmallVector<int64_t> srcLvlShape = srcStt.
getLvlShape();
1552 SmallVector<int64_t> dstDimShape =
1553 dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1555 RankedTensorType::get(dstDimShape, srcStt.
getElementType(), dstEnc);
1556 return build(odsBuilder, odsState, dstTp, source);
1559LogicalResult ReinterpretMapOp::verify() {
1562 ArrayRef<LevelType> srcLvlTps = srcStt.
getLvlTypes();
1563 ArrayRef<LevelType> dstLvlTps = dstStt.
getLvlTypes();
1565 if (srcLvlTps.size() != dstLvlTps.size())
1566 return emitError(
"Level rank mismatch between source/dest tensors");
1568 for (
auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1569 if (srcLvlTp != dstLvlTp)
1570 return emitError(
"Level type mismatch between source/dest tensors");
1574 return emitError(
"Crd/Pos width mismatch between source/dest tensors");
1578 return emitError(
"Element type mismatch between source/dest tensors");
1580 SmallVector<Size> srcLvlShape = srcStt.
getLvlShape();
1581 SmallVector<Size> dstLvlShape = dstStt.
getLvlShape();
1582 for (
auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1583 if (srcLvlSz != dstLvlSz) {
1587 return emitError(
"Level size mismatch between source/dest tensors");
1594OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1598 if (
auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1600 if (def.getSource().getType() == getDest().
getType())
1601 return def.getSource();
1606template <
typename ToBufferOp>
1611 typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1613 Type elemTp =
nullptr;
1614 bool withStride =
false;
1615 if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1617 }
else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1618 std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1620 if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1622 }
else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1626 assert(elemTp &&
"unhandled operation.");
1628 bufShape.push_back(ShapedType::kDynamic);
1630 auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get(
1632 {ShapedType::kDynamic})
1633 : StridedLayoutAttr();
1634 ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1638LogicalResult ToPositionsOp::verify() {
1641 return emitError(
"requested level is out of bounds");
1643 return emitError(
"unexpected type for positions");
1648ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1650 OpaqueProperties prop, RegionRange region,
1651 SmallVectorImpl<mlir::Type> &ret) {
1655LogicalResult ToCoordinatesOp::verify() {
1658 return emitError(
"requested level is out of bounds");
1660 return emitError(
"unexpected type for coordinates");
1665ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1667 OpaqueProperties prop, RegionRange region,
1668 SmallVectorImpl<mlir::Type> &ret) {
1672LogicalResult ToCoordinatesBufferOp::verify() {
1675 return emitError(
"expected sparse tensor with a COO region");
1679LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1680 MLIRContext *ctx, std::optional<Location> loc,
ValueRange ops,
1681 DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
1682 SmallVectorImpl<mlir::Type> &ret) {
1687LogicalResult ToValuesOp::verify() {
1691 return emitError(
"unexpected mismatch in element types");
1695LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1696 std::optional<Location> loc,
1698 OpaqueProperties prop,
1700 SmallVectorImpl<mlir::Type> &ret) {
1704LogicalResult ToSliceOffsetOp::verify() {
1705 auto rank =
getSlice().getType().getRank();
1706 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1707 return emitError(
"requested dimension out of bound");
1711LogicalResult ToSliceStrideOp::verify() {
1712 auto rank =
getSlice().getType().getRank();
1713 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1714 return emitError(
"requested dimension out of bound");
1718LogicalResult GetStorageSpecifierOp::verify() {
1720 getSpecifier(), getOperation());
1723template <
typename SpecifierOp>
1725 return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1728OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1729 const StorageSpecifierKind kind = getSpecifierKind();
1730 const auto lvl = getLevel();
1732 if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1733 return op.getValue();
1737LogicalResult SetStorageSpecifierOp::verify() {
1739 getSpecifier(), getOperation());
1744 const char *regionName,
1747 unsigned expectedNum = inputTypes.size();
1748 if (numArgs != expectedNum)
1749 return op->emitError() << regionName <<
" region must have exactly "
1750 << expectedNum <<
" arguments";
1752 for (
unsigned i = 0; i < numArgs; i++) {
1754 if (typ != inputTypes[i])
1755 return op->emitError() << regionName <<
" region argument " << (i + 1)
1756 <<
" type mismatch";
1760 return op->emitError() << regionName
1761 <<
" region must end with a terminator";
1764 YieldOp yield = dyn_cast<YieldOp>(term);
1766 return op->emitError() << regionName
1767 <<
" region must end with sparse_tensor.yield";
1768 if (!yield.hasSingleResult() ||
1769 yield.getSingleResult().getType() != outputType)
1770 return op->emitError() << regionName <<
" region yield type mismatch";
1775LogicalResult BinaryOp::verify() {
1776 NamedAttrList attrs = (*this)->getAttrs();
1777 Type leftType = getX().getType();
1778 Type rightType = getY().getType();
1779 Type outputType = getOutput().getType();
1780 Region &overlap = getOverlapRegion();
1781 Region &left = getLeftRegion();
1782 Region &right = getRightRegion();
1786 if (!overlap.
empty()) {
1788 TypeRange{leftType, rightType}, outputType)))
1791 if (!left.
empty()) {
1795 }
else if (getLeftIdentity()) {
1796 if (leftType != outputType)
1797 return emitError(
"left=identity requires first argument to have the same "
1798 "type as the output");
1800 if (!right.
empty()) {
1804 }
else if (getRightIdentity()) {
1805 if (rightType != outputType)
1806 return emitError(
"right=identity requires second argument to have the "
1807 "same type as the output");
1812LogicalResult UnaryOp::verify() {
1813 Type inputType = getX().getType();
1814 Type outputType = getOutput().getType();
1818 Region &present = getPresentRegion();
1819 if (!present.
empty()) {
1824 Region &absent = getAbsentRegion();
1825 if (!absent.
empty()) {
1831 Block *parent = getOperation()->getBlock();
1833 cast<YieldOp>(absentBlock->
getTerminator()).getSingleResult();
1834 if (
auto arg = dyn_cast<BlockArgument>(absentVal)) {
1835 if (arg.getOwner() == parent)
1836 return emitError(
"absent region cannot yield linalg argument");
1838 if (!isa<arith::ConstantOp>(def) &&
1839 (def->getBlock() == absentBlock || def->getBlock() == parent))
1840 return emitError(
"absent region cannot yield locally computed value");
1846bool ConcatenateOp::needsExtraSort() {
1851 bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1858 bool directLowerable =
1859 allSameOrdered && getDimension() == 0 && dstStt.
isIdentity();
1860 return !directLowerable;
1863LogicalResult ConcatenateOp::verify() {
1865 const Dimension concatDim = getDimension();
1866 const Dimension dimRank = dstTp.getDimRank();
1868 if (getInputs().size() <= 1)
1869 return emitError(
"Need at least two tensors to concatenate.");
1871 if (concatDim >= dimRank)
1873 "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1874 concatDim, dimRank));
1876 for (
const auto &it : llvm::enumerate(getInputs())) {
1877 const auto i = it.index();
1879 if (srcTp.hasDynamicDimShape())
1880 return emitError(llvm::formatv(
"Input tensor ${0} has dynamic shape", i));
1881 const Dimension srcDimRank = srcTp.getDimRank();
1882 if (srcDimRank != dimRank)
1884 llvm::formatv(
"Input tensor ${0} has a different rank (rank={1}) "
1885 "from the output tensor (rank={2}).",
1886 i, srcDimRank, dimRank));
1889 for (
Dimension d = 0; d < dimRank; d++) {
1890 const Size dstSh = dstTp.getDimShape()[d];
1891 if (d == concatDim) {
1892 if (ShapedType::isStatic(dstSh)) {
1897 for (
const auto src : getInputs())
1903 "The concatenation dimension of the output tensor should be the "
1904 "sum of all the concatenation dimensions of the input tensors.");
1908 for (
const auto src : getInputs()) {
1910 if (ShapedType::isStatic(prev) && sh != prev)
1911 return emitError(
"All dimensions (expect for the concatenating one) "
1912 "should be equal.");
1921void PushBackOp::build(OpBuilder &builder, OperationState &
result,
1922 Value curSize, Value inBuffer, Value value) {
1923 build(builder,
result, curSize, inBuffer, value, Value());
1926LogicalResult PushBackOp::verify() {
1927 if (Value n =
getN()) {
1929 if (nValue && nValue.value() < 1)
1935LogicalResult CompressOp::verify() {
1937 if (stt.
getLvlRank() != 1 +
static_cast<Level>(getLvlCoords().size()))
1938 return emitOpError(
"incorrect number of coordinates");
1942void ForeachOp::build(
1943 OpBuilder &builder, OperationState &
result, Value tensor,
1947 build(builder,
result, initArgs.
getTypes(), tensor, initArgs, order);
1955 SmallVector<Type> blockArgTypes(dimRank, builder.
getIndexType());
1959 blockArgTypes.append(initArgs.
getTypes().begin(), initArgs.
getTypes().end());
1961 SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.
getLoc());
1963 OpBuilder::InsertionGuard guard(builder);
1964 auto ®ion = *
result.regions.front();
1966 builder.
createBlock(®ion, region.end(), blockArgTypes, blockArgLocs);
1967 bodyBuilder(builder,
result.location,
1973LogicalResult ForeachOp::verify() {
1975 const Dimension dimRank = t.getDimRank();
1976 const auto args = getBody()->getArguments();
1978 if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1979 return emitError(
"Level traverse order does not match tensor's level rank");
1981 if (dimRank + 1 + getInitArgs().size() != args.size())
1982 return emitError(
"Unmatched number of arguments in the block");
1984 if (getNumResults() != getInitArgs().size())
1985 return emitError(
"Mismatch in number of init arguments and results");
1987 if (getResultTypes() != getInitArgs().getTypes())
1988 return emitError(
"Mismatch in types of init arguments and results");
1991 auto yield = cast<YieldOp>(getBody()->getTerminator());
1992 if (yield.getNumOperands() != getNumResults() ||
1993 yield.getOperands().getTypes() != getResultTypes())
1994 return emitError(
"Mismatch in types of yield values and results");
1996 const auto iTp = IndexType::get(
getContext());
2000 llvm::formatv(
"Expecting Index type for argument at index {0}", d));
2002 const auto elemTp = t.getElementType();
2003 const auto valueTp = args[dimRank].getType();
2004 if (elemTp != valueTp)
2006 llvm::formatv(
"Unmatched element type between input tensor and "
2007 "block argument, expected:{0}, got: {1}",
2012OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
2015 return getInputCoo();
2020LogicalResult ReorderCOOOp::verify() {
2025 return emitError(
"Expected COO sparse tensors only");
2028 return emitError(
"Unmatched dim2lvl map between input and result COO");
2033 return emitError(
"Unmatched storage format between input and result COO");
2038LogicalResult ReduceOp::verify() {
2039 Type inputType = getX().getType();
2040 Region &formula = getRegion();
2042 TypeRange{inputType, inputType}, inputType);
2045LogicalResult SelectOp::verify() {
2047 Type inputType = getX().getType();
2048 Type boolType =
b.getI1Type();
2049 Region &formula = getRegion();
2054LogicalResult SortOp::verify() {
2055 AffineMap xPerm = getPermMap();
2058 return emitError(llvm::formatv(
"Expected rank(perm_map) > 1, got {0}", nx));
2062 llvm::formatv(
"Expected a permutation map, got {0}", xPerm));
2071 const auto checkDim = [&](Value v,
Size minSize,
2072 const char *message) -> LogicalResult {
2074 if (ShapedType::isStatic(sh) && sh < minSize)
2076 llvm::formatv(
"{0} got {1} < {2}", message, sh, minSize));
2079 uint64_t n = cn.value();
2081 if (
auto nyAttr = getNyAttr())
2082 ny = nyAttr.getInt();
2083 if (
failed(checkDim(getXy(), n * (nx + ny),
2084 "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
2086 for (Value opnd : getYs())
2087 if (
failed(checkDim(opnd, n,
"Expected dimension(y) >= n")))
2097IterSpaceType IteratorType::getIterSpaceType()
const {
2098 return IterSpaceType::get(
getContext(), getEncoding(), getLoLvl(),
2102IteratorType IterSpaceType::getIteratorType()
const {
2103 return IteratorType::get(
getContext(), getEncoding(), getLoLvl(), getHiLvl());
2122 "expect larger level upper bound than lower bound");
2130 IntegerAttr &lvlHiAttr) {
2147 p << lo <<
" to " << hi;
2153 IntegerAttr lvlHi) {
2154 unsigned lo = lvlLo.getValue().getZExtValue();
2155 unsigned hi = lvlHi.getValue().getZExtValue();
2166 unsigned maxCnt = std::numeric_limits<unsigned>::max(),
2169 ParseResult crdList =
2174 definedSet.
set(cnt);
2182 "parsed more value than expected.");
2184 if (failed(crdList)) {
2187 "expecting SSA value or \"_\" for level coordinates");
2189 assert(definedArgs.size() == definedSet.
count());
2196 if (definedSet.
empty())
2199 for (
unsigned i = 0; i < size; i++) {
2200 if (definedSet[i]) {
2201 p << blocksArgs.front();
2202 blocksArgs = blocksArgs.drop_front();
2209 assert(blocksArgs.empty());
2222 for (
auto &coord : coords)
2243 if (iterators.size() != spaces.size())
2246 "mismatch in number of sparse iterators and sparse spaces");
2251 size_t numCrds = coords.size();
2259 blockArgs.append(coords);
2265 if (iterSpaceTps.size() != spaces.size())
2267 "mismatch in number of iteration space operands "
2268 "and iteration space types");
2270 for (
auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
2271 IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
2274 "expected sparse_tensor.iter_space type for "
2275 "iteration space operands");
2276 it.type = spaceTp.getIteratorType();
2291 if (args.size() != initArgs.size() || args.size() != state.
types.size()) {
2294 "mismatch in number of iteration arguments and return values");
2297 for (
auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.
types)) {
2319 size_t numCrds = coords.size();
2327 blockArgs.append(coords);
2335 if (iterSpaceTps.size() != spaces.size())
2337 "mismatch in number of iteration space operands "
2338 "and iteration space types");
2353 if (args.size() != initArgs.size() || args.size() != state.
types.size()) {
2356 "mismatch in number of iteration arguments and return values");
2359 for (
auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.
types)) {
2368LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2369 MLIRContext *ctx, std::optional<Location> loc,
ValueRange ops,
2370 DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
2371 SmallVectorImpl<mlir::Type> &ret) {
2373 ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2375 ret.push_back(IterSpaceType::get(ctx, stt.
getEncoding(), adaptor.getLoLvl(),
2376 adaptor.getHiLvl()));
2380LogicalResult ExtractIterSpaceOp::verify() {
2381 if (getLoLvl() >= getHiLvl())
2382 return emitOpError(
"expected smaller level low than level high");
2385 if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2387 "parent iterator should be specified iff level lower bound equals 0");
2391 IterSpaceType spaceTp = getExtractedSpace().getType();
2392 if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2394 "mismatch in parent iterator encoding and iteration space encoding.");
2396 if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2397 return emitOpError(
"parent iterator should be used to extract an "
2398 "iteration space from a consecutive level.");
2404LogicalResult ExtractValOp::verify() {
2406 auto itTp = getIterator().getType();
2409 return emitOpError(
"mismatch in tensor encoding and iterator encoding.");
2412 return emitOpError(
"must use last-level iterator to extract values. ");
2423 llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
2424 for (
unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
2425 if (
auto crd = iterateOp.getLvlCrd(i)) {
2426 if (crd->getUsers().empty())
2427 toRemove.set(crd->getArgNumber());
2434 if (toRemove.none())
2438 iterateOp.setCrdUsedLvls(newUsedLvls);
2439 iterateOp.getBody()->eraseArguments(toRemove);
2445void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
2446 mlir::MLIRContext *context) {
2447 results.
add<RemoveUnusedLvlCrds>(context);
2450void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2452 unsigned rank = llvm::cast<IterSpaceType>(iterSpace.
getType()).getSpaceDim();
2455 return build(builder, odsState, iterSpace, initArgs, set);
2458void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2461 OpBuilder::InsertionGuard guard(builder);
2467 Region *bodyRegion = odsState.
addRegion();
2472 for (Value v : initArgs)
2476 for (
unsigned i = 0, e = crdUsedLvls.
count(); i < e; i++)
2481 llvm::cast<IterSpaceType>(iterSpace.
getType()).getIteratorType(),
2485ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &
result) {
2486 OpAsmParser::Argument iterator;
2487 OpAsmParser::UnresolvedOperand iterSpace;
2489 SmallVector<OpAsmParser::Argument> iters, iterArgs;
2492 if (iters.size() != 1)
2494 "expected only one iterator/iteration space");
2496 iterArgs.append(iters);
2497 Region *body =
result.addRegion();
2517 StringRef prefix =
"") {
2518 assert(blocksArgs.size() == initializers.size() &&
2519 "expected same length of arguments and initializers");
2520 if (initializers.empty())
2524 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
2525 p << std::get<0>(it) <<
" = " << std::get<1>(it);
2530template <
typename SparseLoopOp>
2532 if (op.getInitArgs().size() != op.getNumResults()) {
2533 return op.emitOpError(
2534 "mismatch in number of loop-carried values and defined values");
2536 if (op.getCrdUsedLvls().max() > op.getSpaceDim())
2537 return op.emitOpError(
"required out-of-bound coordinates");
2545void IterateOp::print(OpAsmPrinter &p) {
2546 p <<
" " << getIterator() <<
" in " << getIterSpace();
2547 if (!getCrdUsedLvls().empty()) {
2554 p <<
" : " << getIterSpace().getType() <<
" ";
2555 if (!getInitArgs().empty())
2560 !getInitArgs().empty());
2563LogicalResult IterateOp::verifyRegions() {
2564 if (getIterator().
getType() != getIterSpace().
getType().getIteratorType())
2565 return emitOpError(
"mismatch in iterator and iteration space type");
2566 if (getNumRegionIterArgs() != getNumResults())
2568 "mismatch in number of basic block args and defined values");
2570 auto initArgs = getInitArgs();
2571 auto iterArgs = getRegionIterArgs();
2572 auto yieldVals = getYieldedValues();
2573 auto opResults = getResults();
2574 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2575 opResults.size()})) {
2576 return emitOpError() <<
"number mismatch between iter args and results.";
2579 for (
auto [i, init, iter, yield, ret] :
2580 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2581 if (init.getType() != ret.getType())
2582 return emitOpError() <<
"types mismatch between " << i
2583 <<
"th iter operand and defined value";
2584 if (iter.getType() != ret.getType())
2585 return emitOpError() <<
"types mismatch between " << i
2586 <<
"th iter region arg and defined value";
2587 if (yield.getType() != ret.getType())
2588 return emitOpError() <<
"types mismatch between " << i
2589 <<
"th yield value and defined value";
/// Returns the loop regions of this op: a one-element vector holding a pointer
/// to the region returned by `getRegion()`.
2596SmallVector<Region *> IterateOp::getLoopRegions() {
return {&getRegion()}; }
2598MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
2599 return getInitArgsMutable();
2603 return getRegion().getArguments().take_front(getNumRegionIterArgs());
2606std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
2607 return cast<sparse_tensor::YieldOp>(
2608 getRegion().getBlocks().front().getTerminator())
2609 .getResultsMutable();
/// Returns the loop results of this op, i.e. the value of `getResults()`
/// wrapped in the optional return type.
2612std::optional<ResultRange> IterateOp::getLoopResults() {
return getResults(); }
2614OperandRange IterateOp::getEntrySuccessorOperands(RegionSuccessor successor) {
2615 return getInitArgs();
2618void IterateOp::getSuccessorRegions(RegionBranchPoint point,
2619 SmallVectorImpl<RegionSuccessor> ®ions) {
2622 regions.push_back(RegionSuccessor(&getRegion()));
2627ValueRange IterateOp::getSuccessorInputs(RegionSuccessor successor) {
2632void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
2634 unsigned numCases) {
2636 cast<IterSpaceType>(iterSpaces.front().
getType()).getSpaceDim();
2643 SmallVector<int64_t> caseBits(numCases, 0);
2645 return CoIterateOp::build(builder, odsState, initArgs.
getTypes(), iterSpaces,
2646 initArgs, set, cases,
2650ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &
result) {
2652 SmallVector<Value> spaces;
2655 SmallVector<OpAsmParser::Argument> blockArgs;
2659 result.addAttribute(
"operandSegmentSizes",
2661 {static_cast<int32_t>(spaces.size()),
2662 static_cast<int32_t>(result.types.size())}));
2664 SmallVector<Attribute> cases;
2668 SmallVector<OpAsmParser::Argument> definedIts;
2675 for (
auto [i, definedIdx] : llvm::enumerate(definedItSet.
bits())) {
2677 auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].
getType());
2678 definedIts[i].type = spaceTp.getIteratorType();
2680 definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
2681 Region *body =
result.addRegion();
2685 CoIterateOp::ensureTerminator(*body, parser.
getBuilder(),
result.location);
2697void CoIterateOp::print(OpAsmPrinter &p) {
2699 llvm::interleaveComma(getIterSpaces(), p, [&](
auto s) { p << s; });
2702 if (!getCrdUsedLvls().empty()) {
2710 p <<
" : (" << getIterSpaces().getTypes() <<
")";
2711 if (!getInitArgs().empty())
2712 p.printArrowTypeList(getInitArgs().getTypes());
2714 for (
unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
2718 getRegionDefinedSpace(idx));
2720 p.printRegion(getRegion(idx),
false,
2721 !getInitArgs().empty());
2725ValueRange CoIterateOp::getYieldedValues(
unsigned regionIdx) {
2726 return cast<sparse_tensor::YieldOp>(
2727 getRegion(regionIdx).getBlocks().front().getTerminator())
2731LogicalResult CoIterateOp::verifyRegions() {
2732 for (
unsigned r = 0, e = getNumRegions(); r < e; r++) {
2733 if (getNumRegionIterArgs() != getNumResults())
2735 "mismatch in number of basic block args and defined values");
2737 auto initArgs = getInitArgs();
2738 auto iterArgs = getRegionIterArgs(r);
2739 auto yieldVals = getYieldedValues(r);
2740 auto opResults = getResults();
2741 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2742 opResults.size()})) {
2744 <<
"number mismatch between iter args and results on " << r
2748 for (
auto [i, init, iter, yield, ret] :
2749 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2750 if (init.getType() != ret.getType())
2752 <<
"types mismatch between " << i
2753 <<
"th iter operand and defined value on " << r <<
"th region";
2754 if (iter.getType() != ret.getType())
2755 return emitOpError() <<
"types mismatch between " << i
2756 <<
"th iter region arg and defined value on " << r
2758 if (yield.getType() != ret.getType())
2760 <<
"types mismatch between " << i
2761 <<
"th yield value and defined value on " << r <<
"th region";
2765 auto cases = getRegionDefinedSpaces();
2766 llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
2767 if (set.size() != getNumRegions())
2773SmallVector<Region *> CoIterateOp::getSubCasesOf(
unsigned regionIdx) {
2774 SmallVector<Region *> ret;
2775 I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
2776 for (Region &r : getCaseRegions())
2777 if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
2789Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2790 Attribute value, Type type,
2792 if (
auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2797void SparseTensorDialect::initialize() {
2799#define GET_ATTRDEF_LIST
2800#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
2803#define GET_TYPEDEF_LIST
2804#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
2808#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2810 declarePromisedInterfaces<
2811 bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
2812 NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
2813 ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
2816#define GET_OP_CLASSES
2817#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2819#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static bool isPermutation(const std::vector< PermutationTy > &permutation)
static Type getElementType(Type type)
Determine the element type of type.
if copies could not be generated due to yet unimplemented cases. copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock specify the insertion points where the incoming copies and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static bool isUnique(It begin, It end)
static LogicalResult verifyNumBlockArgs(T *op, Region &region, const char *regionName, TypeRange inputTypes, Type outputType)
static ParseResult parseOptionalStaticSlice(int64_t &result, AsmParser &parser)
static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc)
We normalized sparse tensor encoding attribute by always using ordered/unique LT such that "compresse...
static ParseResult parseUsedCoordList(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &coords)
static LogicalResult isMatchingWidth(Value mem, unsigned width)
static constexpr bool acceptBitWidth(unsigned bitWidth)
static mlir::ParseResult parseLevelRange(mlir::AsmParser &, mlir::sparse_tensor::Level &, mlir::sparse_tensor::Level &)
Parses a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static LogicalResult lvlIsInBounds(Level lvl, Value tensor)
static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size, Block::BlockArgListType blocksArgs, I64BitSet definedSet)
static constexpr FieldIndex kDataFieldStartingIdx
static constexpr Level kInvalidLevel
static LogicalResult verifySparseLoopOp(SparseLoopOp op)
static constexpr Level kInvalidFieldIndex
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, mlir::sparse_tensor::Level)
Prints a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind)
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op)
static ParseResult parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &iterators, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static SmallVector< Size > getSparseFieldShape(const SparseTensorEncodingAttr enc, std::optional< ArrayRef< int64_t > > dimShape)
static ParseResult parseOptionalDefinedList(OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, SmallVectorImpl< OpAsmParser::Argument > &definedArgs, unsigned maxCnt=std::numeric_limits< unsigned >::max(), OpAsmParser::Delimiter delimiter=OpAsmParser::Delimiter::Paren)
Parses an optional defined list in the form of "(%val0, _, %val1, ...)", where _ is used to an...
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, SparseTensorType stt, RankedTensorType valTp, TypeRange lvlTps)
static ParseResult parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< Value > &spacesVals, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static LogicalResult verifySparsifierGetterSetter(StorageSpecifierKind mdKind, std::optional< Level > lvl, TypedValue< StorageSpecifierType > md, Operation *op)
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, OpaqueProperties prop, RegionRange region, SmallVectorImpl< mlir::Type > &ret)
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
void print(raw_ostream &os) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseLBrace()=0
Parse a { token.
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseQuestion()=0
Parse a '?' token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printArrowTypeList(TypeRange &&types)
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
BlockArgListType getArguments()
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class provides an abstraction over the different types of ranges over Regions.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
A simple wrapper to encode a bitset of (at most 64) levels, currently used by sparse_tensor....
iterator_range< const_set_bits_iterator > bits() const
I64BitSet & set(unsigned i)
A wrapper around RankedTensorType, which has three goals:
bool isSingletonLvl(Level l) const
SmallVector< Size > getBatchLvlShape() const
Returns the batched level-shape.
MLIRContext * getContext() const
Type getElementType() const
bool isLooseCompressedLvl(Level l) const
unsigned getCrdWidth() const
Returns the coordinate-overhead bitwidth, defaulting to zero.
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
bool isAllOrdered() const
Returns true for tensors where every level is ordered.
bool isCOOType(Level startLvl=0, bool isUnique=true) const
Returns true iff this sparse tensor type has a trailing COO region starting at the given level.
Dimension getDimRank() const
Returns the dimension-rank.
AffineMap getLvlToDim() const
Returns the lvlToDim mapping (or the null-map for the identity).
Attribute getImplicitVal() const
Returns the implicit value, defaulting to null Attribute for 0.
bool isAllDense() const
Returns true for tensors where every level is dense.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
bool hasSameDimToLvl(const SparseTensorType &other) const
Returns true iff the two types have the same mapping.
ArrayRef< Size > getDimShape() const
Returns the dimension-shape.
SmallVector< Size > getLvlShape() const
Returns the level-shape.
bool isCompressedLvl(Level l) const
bool hasStaticDimShape() const
Returns true if no dimension has dynamic size.
Level getLvlRank() const
Returns the level-rank.
ArrayRef< LevelType > getLvlTypes() const
unsigned getPosWidth() const
Returns the position-overhead bitwidth, defaulting to zero.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
LevelType getLvlType(Level l) const
AffineMap getDimToLvl() const
Returns the dimToLvl mapping (or the null-map for the identity).
Attribute getExplicitVal() const
Returns the explicit value, defaulting to null Attribute for unset.
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
bool isUniqueLvl(Level l) const
Provides methods to access fields of a sparse tensor with the given encoding.
unsigned getNumDataFields() const
Gets the total number of data fields (coordinate arrays, position arrays, and a value array) for the ...
unsigned getNumFields() const
Gets the total number of fields for the given sparse tensor encoding.
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with t...
std::pair< FieldIndex, unsigned > getFieldIndexAndStride(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Parses the Sparse Tensor Encoding Attribute (STEA).
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
bool isUniqueLT(LevelType lt)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool isWithCrdLT(LevelType lt)
std::optional< LevelType > buildLevelType(LevelFormat lf, const std::vector< LevelPropNonDefault > &properties, uint64_t n=0, uint64_t m=0)
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
bool isWithPosLT(LevelType lt)
bool isOrderedLT(LevelType lt)
std::string toMLIRString(LevelType lt)
Dimension toDim(SparseTensorEncodingAttr enc, Level l)
Convenience method to translate the given level to the corresponding dimension.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
bool isSingletonLT(LevelType lt)
static llvm::hash_code hash_value(LevelType lt)
uint64_t getN(LevelType lt)
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context)
Given the dimToLvl map, infers the lvlToDim map, or returns empty Affine map when inference fails.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
bool isBlockSparsity(AffineMap dimToLvl)
Given the dimToLvl map, returns if it's block sparsity.
bool isDenseLT(LevelType lt)
uint64_t getM(LevelType lt)
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SparseTensorFieldKind
The sparse tensor storage...
bool isBatchLT(LevelType lt)
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context)
Returns the lvlToDim map for the given dimToLvl map specific to the block sparse cases.
bool isNOutOfMLT(LevelType lt)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
llvm::function_ref< Fn > function_ref
LogicalResult matchAndRewrite(IterateOp iterateOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
A simple structure that encodes a range of levels in the sparse tensors that forms a COO segment.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.
LevelType stripStorageIrrelevantProperties() const