26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/FormatVariadic.h"
29#define GET_ATTRDEF_CLASSES
30#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
31#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
41#define GET_TYPEDEF_CLASSES
42#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
51 return llvm::hash_value(
static_cast<uint64_t
>(lt));
79 if (dimShape.has_value()) {
83 enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
84 memrefShape.assign(lvlShape.begin(),
85 lvlShape.begin() + enc.getBatchLvlRank());
88 memrefShape.push_back(ShapedType::kDynamic);
104 const auto lvlTypes = enc.getLvlTypes();
105 const Level lvlRank = enc.getLvlRank();
111 for (
Level l = 0; l < lvlRank; ) {
112 const auto lt = lvlTypes[l];
121 if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
122 if (!cooSegsRef.front().isSoA) {
125 l = cooSegsRef.front().lvlRange.second;
131 cooSegsRef = cooSegsRef.drop_front();
159 const Type posMemType = MemRefType::get(memrefShape, stt.
getPosType());
161 const Type crdMemType = MemRefType::get(memrefShape, stt.
getCrdType());
171 return callback(specType, fieldIdx, fieldKind, lvl, lt);
173 return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
175 return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
177 return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
179 llvm_unreachable(
"unrecognized field kind");
184 unsigned numFields = 0;
194 unsigned numFields = 0;
206std::pair<FieldIndex, unsigned>
208 std::optional<Level> lvl)
const {
212 assert(lvl.has_value());
213 const Level cooStart = enc.getAoSCOOStart();
214 const Level lvlRank = enc.getLvlRank();
215 if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
217 stride = lvlRank - cooStart;
223 if ((lvl && fLvl == lvl.value() && kind == fKind) ||
232 return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
239std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(
int64_t v) {
240 return isDynamic(v) ? std::nullopt
241 : std::make_optional(
static_cast<uint64_t
>(v));
244std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset()
const {
245 return getStatic(getOffset());
248std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride()
const {
249 return getStatic(getStride());
252std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize()
const {
253 return getStatic(getSize());
256bool SparseTensorDimSliceAttr::isCompletelyDynamic()
const {
257 return isDynamic(getOffset()) && isDynamic(getStride()) &&
258 isDynamic(getSize());
261std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
262 return isDynamic(v) ?
"?" : std::to_string(v);
265void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os)
const {
266 assert(getImpl() &&
"Uninitialized SparseTensorDimSliceAttr");
268 os << getStaticString(getOffset());
270 os << getStaticString(getSize());
272 os << getStaticString(getStride());
276void SparseTensorDimSliceAttr::print(AsmPrinter &printer)
const {
283 if (parseResult.has_value()) {
284 if (parseResult.value().succeeded() &&
result < 0) {
287 "expect positive value or ? for slice offset/size/stride");
290 return parseResult.value();
294 result = SparseTensorDimSliceAttr::kDynamic;
298Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
299 int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
311 offset, size, stride);
316 int64_t offset, int64_t size, int64_t stride) {
317 if (!isDynamic(offset) && offset < 0)
318 return emitError() <<
"expect non-negative value or ? for slice offset";
319 if (!isDynamic(size) && size <= 0)
320 return emitError() <<
"expect positive value or ? for slice size";
321 if (!isDynamic(stride) && stride <= 0)
322 return emitError() <<
"expect positive value or ? for slice stride";
326SparseTensorEncodingAttr
327SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl)
const {
328 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
329 return SparseTensorEncodingAttr::get(
330 getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
331 getCrdWidth(), getExplicitVal(), getImplicitVal());
334SparseTensorEncodingAttr
335SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc)
const {
336 return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
339SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl()
const {
340 return withDimToLvl(AffineMap());
343SparseTensorEncodingAttr
344SparseTensorEncodingAttr::withBitWidths(
unsigned posWidth,
345 unsigned crdWidth)
const {
346 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
347 return SparseTensorEncodingAttr::get(
348 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
349 crdWidth, getExplicitVal(), getImplicitVal());
352SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths()
const {
353 return withBitWidths(0, 0);
356SparseTensorEncodingAttr
357SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal)
const {
358 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
359 return SparseTensorEncodingAttr::get(
360 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
361 getCrdWidth(), explicitVal, getImplicitVal());
364SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal()
const {
365 return withExplicitVal(Attribute());
368SparseTensorEncodingAttr
369SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal)
const {
370 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
371 return SparseTensorEncodingAttr::get(
372 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
373 getCrdWidth(), getExplicitVal(), implicitVal);
376SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal()
const {
377 return withImplicitVal(Attribute());
380SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
381 ArrayRef<SparseTensorDimSliceAttr> dimSlices)
const {
382 return SparseTensorEncodingAttr::get(
383 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
384 getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
387SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices()
const {
388 return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
391uint64_t SparseTensorEncodingAttr::getBatchLvlRank()
const {
392 ArrayRef<LevelType> lvlTypes = getLvlTypes();
393 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(),
isBatchLT);
394 return std::distance(lastBatch, lvlTypes.rend());
397bool SparseTensorEncodingAttr::isAllDense()
const {
398 return !getImpl() || llvm::all_of(getLvlTypes(),
isDenseLT);
401bool SparseTensorEncodingAttr::isAllOrdered()
const {
402 return !getImpl() || llvm::all_of(getLvlTypes(),
isOrderedLT);
405Type SparseTensorEncodingAttr::getCrdElemType()
const {
409 return IntegerType::get(
getContext(), getCrdWidth());
413Type SparseTensorEncodingAttr::getPosElemType()
const {
417 return IntegerType::get(
getContext(), getPosWidth());
421MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
422 std::optional<ArrayRef<int64_t>> dimShape)
const {
424 return MemRefType::get(shape, getCrdElemType());
427MemRefType SparseTensorEncodingAttr::getPosMemRefType(
428 std::optional<ArrayRef<int64_t>> dimShape)
const {
430 return MemRefType::get(shape, getPosElemType());
433bool SparseTensorEncodingAttr::isIdentity()
const {
434 return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
437bool SparseTensorEncodingAttr::isPermutation()
const {
438 return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
441Dimension SparseTensorEncodingAttr::getDimRank()
const {
442 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
443 const auto dimToLvl = getDimToLvl();
444 return dimToLvl ? dimToLvl.
getNumDims() : getLvlRank();
447Level SparseTensorEncodingAttr::getLvlRank()
const {
448 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
449 return getLvlTypes().size();
455 assert(l < getLvlRank() &&
"Level is out of bounds");
456 return getLvlTypes()[l];
459bool SparseTensorEncodingAttr::isSlice()
const {
460 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
461 return !getDimSlices().empty();
464SparseTensorDimSliceAttr
465SparseTensorEncodingAttr::getDimSlice(
Dimension dim)
const {
466 assert(isSlice() &&
"Is not a slice");
467 const auto dimSlices = getDimSlices();
468 assert(dim < dimSlices.size() &&
"Dimension is out of bounds");
469 return dimSlices[dim];
472std::optional<uint64_t>
473SparseTensorEncodingAttr::getStaticDimSliceOffset(
Dimension dim)
const {
474 return getDimSlice(dim).getStaticOffset();
477std::optional<uint64_t>
478SparseTensorEncodingAttr::getStaticDimSliceStride(
Dimension dim)
const {
479 return getDimSlice(dim).getStaticStride();
482std::optional<uint64_t>
483SparseTensorEncodingAttr::getStaticLvlSliceOffset(
Level lvl)
const {
484 return getStaticDimSliceOffset(
toDim(*
this, lvl));
487std::optional<uint64_t>
488SparseTensorEncodingAttr::getStaticLvlSliceStride(
Level lvl)
const {
489 return getStaticDimSliceStride(
toDim(*
this, lvl));
493SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
494 CrdTransDirectionKind dir)
const {
496 return SmallVector<int64_t>(srcShape);
498 SmallVector<int64_t> ret;
500 dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
504 for (
unsigned r = 0; r < rank; r++) {
505 unsigned trans = dir == CrdTransDirectionKind::dim2lvl ?
toDim(*
this, r)
507 ret.push_back(srcShape[trans]);
514 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
516 SmallVector<AffineExpr> dimRep;
517 dimRep.reserve(srcShape.size());
518 for (int64_t sz : srcShape) {
519 if (ShapedType::isStatic(sz)) {
531 unsigned numSymbols = getDimToLvl().getNumSymbols();
533 for (AffineExpr exp : transMap.
getResults()) {
536 srcShape.size(), numSymbols);
538 if (
auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
539 ret.push_back(c.getValue() + 1);
541 if (
auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
545 if (
auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
546 ret.push_back(bound.getValue());
550 ret.push_back(ShapedType::kDynamic);
553 assert(ret.size() == rank);
558SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
560 CrdTransDirectionKind dir)
const {
564 SmallVector<Type> retType(
565 dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
568 CrdTranslateOp::create(builder, loc, retType, crds, dir, *
this);
569 return transOp.getOutCrds();
572Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
580 SmallVector<LevelType> lvlTypes;
581 SmallVector<SparseTensorDimSliceAttr> dimSlices;
582 AffineMap dimToLvl = {};
583 AffineMap lvlToDim = {};
584 unsigned posWidth = 0;
585 unsigned crdWidth = 0;
586 Attribute explicitVal;
587 Attribute implicitVal;
589 SmallVector<StringRef, 5> keys = {
"map",
"posWidth",
"crdWidth",
590 "explicitVal",
"implicitVal"};
593 auto *it = find(keys, attrName);
594 if (it == keys.end()) {
598 unsigned keyWordIndex = it - keys.begin();
603 switch (keyWordIndex) {
606 auto res = cParser.parseDimLvlMap();
609 const auto &dlm = *res;
611 const Level lvlRank = dlm.getLvlRank();
612 for (
Level lvl = 0; lvl < lvlRank; lvl++)
613 lvlTypes.push_back(dlm.getLvlType(lvl));
615 const Dimension dimRank = dlm.getDimRank();
616 for (
Dimension dim = 0; dim < dimRank; dim++)
617 dimSlices.push_back(dlm.getDimSlice(dim));
621 const auto isDefined = [](SparseTensorDimSliceAttr slice) {
622 return static_cast<bool>(slice.getImpl());
624 if (llvm::any_of(dimSlices, isDefined)) {
625 const auto defaultSlice =
626 SparseTensorDimSliceAttr::get(parser.
getContext());
627 for (
Dimension dim = 0; dim < dimRank; dim++)
628 if (!isDefined(dimSlices[dim]))
629 dimSlices[dim] = defaultSlice;
634 dimToLvl = dlm.getDimToLvlMap(parser.
getContext());
635 lvlToDim = dlm.getLvlToDimMap(parser.
getContext());
642 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
645 "expected an integral position bitwidth");
648 posWidth = intAttr.getInt();
655 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
658 "expected an integral index bitwidth");
661 crdWidth = intAttr.getInt();
668 if (
auto result = llvm::dyn_cast<FloatAttr>(attr)) {
670 }
else if (
auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
672 }
else if (
auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
676 "expected a numeric value for explicitVal");
685 if (
auto result = llvm::dyn_cast<FloatAttr>(attr)) {
687 }
else if (
auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
689 }
else if (
auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
693 "expected a numeric value for implicitVal");
711 if (!lvlToDim || lvlToDim.
isEmpty()) {
714 return parser.
getChecked<SparseTensorEncodingAttr>(
715 parser.
getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
716 explicitVal, implicitVal, dimSlices);
719void SparseTensorEncodingAttr::print(AsmPrinter &printer)
const {
720 auto map =
static_cast<AffineMap
>(getDimToLvl());
724 printer <<
"<{ map = ";
725 printSymbols(map, printer);
727 printDimensions(map, printer, getDimSlices());
729 printLevels(map, printer, getLvlTypes());
733 printer <<
", posWidth = " << getPosWidth();
735 printer <<
", crdWidth = " << getCrdWidth();
736 if (getExplicitVal()) {
737 printer <<
", explicitVal = " << getExplicitVal();
739 if (getImplicitVal())
740 printer <<
", implicitVal = " << getImplicitVal();
744void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
745 AsmPrinter &printer)
const {
749 for (
unsigned i = 0, n = map.
getNumSymbols() - 1; i < n; i++)
750 printer <<
's' << i <<
", ";
756void SparseTensorEncodingAttr::printDimensions(
757 AffineMap &map, AsmPrinter &printer,
758 ArrayRef<SparseTensorDimSliceAttr> dimSlices)
const {
759 if (!dimSlices.empty()) {
760 for (
unsigned i = 0, n = map.
getNumDims() - 1; i < n; i++)
761 printer <<
'd' << i <<
" : " << dimSlices[i] <<
", ";
763 printer <<
'd' << map.
getNumDims() - 1 <<
" : "
767 for (
unsigned i = 0, n = map.
getNumDims() - 1; i < n; i++)
768 printer <<
'd' << i <<
", ";
774void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
775 ArrayRef<LevelType> lvlTypes)
const {
776 for (
unsigned i = 0, n = map.
getNumResults() - 1; i < n; i++) {
787LogicalResult SparseTensorEncodingAttr::verify(
789 AffineMap dimToLvl, AffineMap lvlToDim,
unsigned posWidth,
790 unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
791 ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
793 return emitError() <<
"unexpected position bitwidth: " << posWidth;
795 return emitError() <<
"unexpected coordinate bitwidth: " << crdWidth;
799 while (it != lvlTypes.end()) {
800 if (it == lvlTypes.begin() ||
802 return emitError() <<
"expected compressed or loose_compressed level "
803 "before singleton level";
805 auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(),
isSingletonLT);
807 return emitError() <<
"expected all singleton lvlTypes "
808 "following a singleton level";
810 if (!std::all_of(it, curCOOEnd, [it](
LevelType i) {
814 return emitError() <<
"expected all singleton lvlTypes stored in the "
815 "same memory layout (SoA vs AoS).";
820 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(),
isBatchLT);
821 if (!std::all_of(lastBatch, lvlTypes.rend(),
isBatchLT))
822 return emitError() <<
"Batch lvlType can only be leading levels.";
825 auto soaLvls = llvm::make_filter_range(lvlTypes, [](
LevelType lt) {
828 if (llvm::any_of(soaLvls, [](
LevelType lt) {
831 return emitError() <<
"SoA is only applicable to singleton lvlTypes.";
838 for (
auto [i, lt] : llvm::drop_begin(llvm::enumerate(lvlTypes))) {
840 return emitError() <<
"dense level cannot follow a non-unique level";
844 if (
auto it = llvm::find_if(lvlTypes,
isNOutOfMLT);
845 it != std::end(lvlTypes)) {
846 if (it != lvlTypes.end() - 1)
847 return emitError() <<
"expected n_out_of_m to be the last level type";
848 if (!std::all_of(lvlTypes.begin(), it,
isDenseLT))
849 return emitError() <<
"expected all dense lvlTypes "
850 "before a n_out_of_m level";
854 <<
"expected 1xm block structure for n_out_of_m level";
857 unsigned coefficient = 0;
858 for (
const auto &elem : sizes) {
860 if (elem != coefficient && coefficient != 0) {
861 return emitError() <<
"expected only one blocked level "
862 "with the same coefficients";
867 if (coefficient !=
getM(*it)) {
868 return emitError() <<
"expected coeffiencts of Affine expressions "
869 "to be equal to m of n_out_of_m level";
878 const Level lvlRank = lvlTypes.size();
880 return emitError() <<
"expected a non-empty array for lvlTypes";
886 <<
"level-rank mismatch between dimToLvl and lvlTypes: "
891 return emitError() <<
"failed to infer lvlToDim from dimToLvl";
892 if (lvlToDim && (inferRes != lvlToDim))
893 return emitError() <<
"expected lvlToDim to be an inverse of dimToLvl";
894 if (dimRank > lvlRank)
895 return emitError() <<
"unexpected dimToLvl mapping from " << dimRank
896 <<
" to " << lvlRank;
898 if (!dimSlices.empty()) {
899 if (dimSlices.size() != dimRank)
901 <<
"dimension-rank mismatch between dimSlices and dimToLvl: "
902 << dimSlices.size() <<
" != " << dimRank;
905 if (dimRank != lvlRank)
907 <<
"dimSlices expected dimension-rank to match level-rank: "
908 << dimRank <<
" != " << lvlRank;
913LogicalResult SparseTensorEncodingAttr::verifyEncoding(
914 ArrayRef<Size> dimShape, Type elementType,
919 getPosWidth(), getCrdWidth(), getExplicitVal(),
920 getImplicitVal(), getDimSlices())))
925 const Dimension dimRank = dimShape.size();
927 return emitError() <<
"expected non-scalar sparse tensor";
928 if (getDimRank() != dimRank)
930 <<
"dimension-rank mismatch between encoding and tensor shape: "
931 << getDimRank() <<
" != " << dimRank;
932 if (
auto expVal = getExplicitVal()) {
933 Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
934 if (attrType != elementType) {
935 return emitError() <<
"explicit value type mismatch between encoding and "
936 <<
"tensor element type: " << attrType
937 <<
" != " << elementType;
940 if (
auto impVal = getImplicitVal()) {
941 Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
942 if (attrType != elementType) {
943 return emitError() <<
"implicit value type mismatch between encoding and "
944 <<
"tensor element type: " << attrType
945 <<
" != " << elementType;
948 auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
949 auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
950 auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
951 if ((impFVal && impFVal.getValue().isNonZero()) ||
952 (impIntVal && !impIntVal.getValue().isZero()) ||
953 (impComplexVal && (impComplexVal.getImag().isNonZero() ||
954 impComplexVal.getReal().isNonZero()))) {
955 return emitError() <<
"implicit value must be zero";
961Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart()
const {
962 SmallVector<COOSegment> coo = getCOOSegments();
963 assert(coo.size() == 1 || coo.empty());
964 if (!coo.empty() && coo.front().isAoS()) {
965 return coo.front().lvlRange.first;
970SmallVector<COOSegment>
971mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments()
const {
972 SmallVector<COOSegment> ret;
973 if (getLvlRank() <= 1)
976 ArrayRef<LevelType> lts = getLvlTypes();
978 while (l < getLvlRank()) {
981 auto cur = lts.begin() + l;
982 auto end = std::find_if(cur + 1, lts.end(), [](
LevelType lt) {
983 return !lt.isa<LevelFormat::Singleton>();
985 unsigned cooLen = std::distance(cur, end);
991 ret.push_back(
COOSegment{std::make_pair(l, l + cooLen),
1012 for (
Level l = startLvl + 1; l < lvlRank; ++l)
1024 lvlTypes.reserve(lvlRank);
1031 std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
1036 auto enc = SparseTensorEncodingAttr::get(
1046SparseTensorEncodingAttr
1048 if (
auto ttp = llvm::dyn_cast<RankedTensorType>(type))
1049 return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
1050 if (
auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
1051 return mdtp.getEncoding();
1057 auto map =
static_cast<AffineMap>(dimToLvl);
1074 lvlExprs.reserve(numLvls);
1077 std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
1078 for (
unsigned i = 0, n = numLvls; i < n; i++) {
1080 if (
auto binOp = dyn_cast<AffineBinaryOpExpr>(
result)) {
1083 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1084 assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
1085 "expected only one floordiv for each dimension");
1090 components.push_back(binOp.getRHS());
1092 lvlExprComponents[pos] = components;
1094 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1095 assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
1096 "expected floordiv before mod");
1101 assert(
false &&
"expected floordiv or mod");
1111 for (
auto &components : lvlExprComponents) {
1112 assert(components.second.size() == 3 &&
1113 "expected 3 components to build lvlExprs");
1118 lvlExprs.push_back(addOp);
1125 "expected dimToLvl to be block sparsity for calling getBlockSize");
1128 if (
auto binOp = dyn_cast<AffineBinaryOpExpr>(
result)) {
1130 blockSize.push_back(
1131 dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
1134 blockSize.push_back(0);
1143 std::map<unsigned, int64_t> coeffientMap;
1144 bool hasBlock =
false;
1146 if (
auto binOp = dyn_cast<AffineBinaryOpExpr>(
result)) {
1148 auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
1149 auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
1150 if (!dimOp || !conOp || conOp.getValue() <= 0)
1153 auto pos = dimOp.getPosition();
1156 auto [it,
inserted] = coeffientMap.try_emplace(pos);
1160 it->second = conOp.getValue();
1163 auto it = coeffientMap.find(pos);
1164 if (it == coeffientMap.end())
1167 if (conOp.getValue() != it->second)
1173 }
else if (
auto dimOp = dyn_cast<AffineDimExpr>(
result)) {
1174 auto pos = dimOp.getPosition();
1176 if (!coeffientMap.try_emplace(pos, 0).second)
1186 auto hasNonIdentityMap = [](
Value v) {
1191 return llvm::any_of(op->
getOperands(), hasNonIdentityMap) ||
1192 llvm::any_of(op->
getResults(), hasNonIdentityMap);
1197 assert(enc.isPermutation() &&
"Non permutation map not supported");
1198 if (
const auto dimToLvl = enc.getDimToLvl())
1206 assert(enc.isPermutation() &&
"Non permutation map not supported");
1207 if (
const auto lvlToDim = enc.getLvlToDim())
1217static SparseTensorEncodingAttr
1220 for (
auto lt : enc.getLvlTypes())
1223 return SparseTensorEncodingAttr::get(
1224 enc.getContext(), lts,
1234 enc.getDimSlices());
1238StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
1245 SparseTensorEncodingAttr encoding) {
1264 StorageSpecifierKind mdKind, std::optional<Level> lvl,
1266 if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
1268 "redundant level argument for querying value memory size");
1271 const auto enc = md.getType().getEncoding();
1272 const Level lvlRank = enc.getLvlRank();
1274 if (mdKind == StorageSpecifierKind::DimOffset ||
1275 mdKind == StorageSpecifierKind::DimStride)
1277 return op->
emitError(
"requested slice data on non-slice tensor");
1279 if (mdKind != StorageSpecifierKind::ValMemSize) {
1281 return op->
emitError(
"missing level argument");
1283 const Level l = lvl.value();
1285 return op->
emitError(
"requested level is out of bounds");
1287 if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
1289 "requested position memory size on a singleton level");
1305 llvm_unreachable(
"Unrecognizable FieldKind");
1310 RankedTensorType valTp,
1313 return op->
emitError(
"the sparse-tensor must have static shape");
1315 return op->
emitError(
"the sparse-tensor must have an encoding attribute");
1321 auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
1323 unsigned expCOORank = stt.
getLvlRank() - cooStartLvl;
1324 if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
1325 return op->
emitError(
"input/output trailing COO level-ranks don't match");
1332 return op->
emitError(
"inconsistent number of fields between input/output");
1335 bool misMatch =
false;
1342 Type inputTp =
nullptr;
1346 assert(fid == idx && stt.
getLvlType(lvl) == lt);
1347 inputTp = lvlTps[idx++];
1350 Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
1352 if (inpElemTp != expElemTp) {
1360 return op->
emitError(
"input/output element-types don't match");
1364LogicalResult AssembleOp::verify() {
1365 RankedTensorType valuesTp = getValues().getType();
1366 const auto lvlsTp = getLevels().getTypes();
1371LogicalResult DisassembleOp::verify() {
1373 return emitError(
"output values and return value type mismatch");
1375 for (
auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1376 if (ot.getType() != rt.getType())
1377 return emitError(
"output levels and return levels type mismatch");
1379 RankedTensorType valuesTp = getRetValues().getType();
1380 const auto lvlsTp = getRetLevels().getTypes();
1385LogicalResult ConvertOp::verify() {
1386 RankedTensorType tp1 = getSource().getType();
1387 RankedTensorType tp2 = getDest().getType();
1388 if (tp1.getRank() != tp2.getRank())
1389 return emitError(
"unexpected conversion mismatch in rank");
1391 llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
1392 if (dstEnc && dstEnc.isSlice())
1393 return emitError(
"cannot convert to a sparse tensor slice");
1395 auto shape1 = tp1.getShape();
1396 auto shape2 = tp2.getShape();
1400 for (
Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1401 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1402 return emitError(
"unexpected conversion mismatch in dimension ") << d;
1406OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
1412bool ConvertOp::needsExtraSort() {
1431 if (
auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1432 if (isa<SparseElementsAttr>(constOp.getValue()))
1438LogicalResult CrdTranslateOp::verify() {
1439 uint64_t inRank = getEncoder().getLvlRank();
1440 uint64_t outRank = getEncoder().getDimRank();
1442 if (getDirection() == CrdTransDirectionKind::dim2lvl)
1443 std::swap(inRank, outRank);
1445 if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1446 return emitError(
"Coordinate rank mismatch with encoding");
1451LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
1452 SmallVectorImpl<OpFoldResult> &results) {
1453 if (getEncoder().isIdentity()) {
1454 results.assign(getInCrds().begin(), getInCrds().end());
1458 AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
1459 ? getEncoder().getDimToLvl()
1460 : getEncoder().getLvlToDim();
1462 results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
1467 auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
1468 bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
1474 bool oppositeDir = def.getDirection() != getDirection();
1476 def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
1477 bool sameCount = def.getNumResults() == getInCrds().size();
1478 if (!oppositeDir || !sameOracle || !sameCount)
1483 bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
1484 [](
auto valuePair) {
1485 auto [
lhs,
rhs] = valuePair;
1493 results.append(def.getInCrds().begin(), def.getInCrds().end());
1497void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1500 return build(builder, state, source, val);
1503LogicalResult LvlOp::verify() {
1504 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1506 if (
static_cast<uint64_t
>(lvl.value()) >= stt.
getLvlRank())
1508 "Level index exceeds the rank of the input sparse tensor");
1513std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1523 cast<RankedTensorType>(getSource().
getType()).getRank());
1527OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1528 auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1532 Level lvl = lvlIndex.getAPSInt().getZExtValue();
1542 auto getIndexAttr = [
this](int64_t lvlSz) {
1543 return IntegerAttr::get(IndexType::get(
getContext()), APInt(64, lvlSz));
1547 if (ShapedType::isStatic(lvlShape[lvl]))
1548 return getIndexAttr(lvlShape[lvl]);
1553void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1554 SparseTensorEncodingAttr dstEnc, Value source) {
1556 SmallVector<int64_t> srcLvlShape = srcStt.
getLvlShape();
1557 SmallVector<int64_t> dstDimShape =
1558 dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1560 RankedTensorType::get(dstDimShape, srcStt.
getElementType(), dstEnc);
1561 return build(odsBuilder, odsState, dstTp, source);
1564LogicalResult ReinterpretMapOp::verify() {
1567 ArrayRef<LevelType> srcLvlTps = srcStt.
getLvlTypes();
1568 ArrayRef<LevelType> dstLvlTps = dstStt.
getLvlTypes();
1570 if (srcLvlTps.size() != dstLvlTps.size())
1571 return emitError(
"Level rank mismatch between source/dest tensors");
1573 for (
auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1574 if (srcLvlTp != dstLvlTp)
1575 return emitError(
"Level type mismatch between source/dest tensors");
1579 return emitError(
"Crd/Pos width mismatch between source/dest tensors");
1583 return emitError(
"Element type mismatch between source/dest tensors");
1585 SmallVector<Size> srcLvlShape = srcStt.
getLvlShape();
1586 SmallVector<Size> dstLvlShape = dstStt.
getLvlShape();
1587 for (
auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1588 if (srcLvlSz != dstLvlSz) {
1592 return emitError(
"Level size mismatch between source/dest tensors");
1599OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1603 if (
auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1605 if (def.getSource().getType() == getDest().
getType())
1606 return def.getSource();
1611template <
typename ToBufferOp>
1615 typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1617 Type elemTp =
nullptr;
1618 bool withStride =
false;
1619 if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1621 }
else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1622 std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1624 if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1626 }
else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1630 assert(elemTp &&
"unhandled operation.");
1632 bufShape.push_back(ShapedType::kDynamic);
1634 auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get(
1636 {ShapedType::kDynamic})
1637 : StridedLayoutAttr();
1638 ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1642LogicalResult ToPositionsOp::verify() {
1645 return emitError(
"requested level is out of bounds");
1647 return emitError(
"unexpected type for positions");
1652ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1654 PropertyRef prop, RegionRange region,
1655 SmallVectorImpl<mlir::Type> &ret) {
1659LogicalResult ToCoordinatesOp::verify() {
1662 return emitError(
"requested level is out of bounds");
1664 return emitError(
"unexpected type for coordinates");
1669ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1671 PropertyRef prop, RegionRange region,
1672 SmallVectorImpl<mlir::Type> &ret) {
1676LogicalResult ToCoordinatesBufferOp::verify() {
1679 return emitError(
"expected sparse tensor with a COO region");
1683LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1684 MLIRContext *ctx, std::optional<Location> loc,
ValueRange ops,
1685 DictionaryAttr attr, PropertyRef prop, RegionRange region,
1686 SmallVectorImpl<mlir::Type> &ret) {
1691LogicalResult ToValuesOp::verify() {
1695 return emitError(
"unexpected mismatch in element types");
1699LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1700 std::optional<Location> loc,
1702 PropertyRef prop, RegionRange region,
1703 SmallVectorImpl<mlir::Type> &ret) {
1707LogicalResult ToSliceOffsetOp::verify() {
1708 auto rank =
getSlice().getType().getRank();
1709 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1710 return emitError(
"requested dimension out of bound");
1714LogicalResult ToSliceStrideOp::verify() {
1715 auto rank =
getSlice().getType().getRank();
1716 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1717 return emitError(
"requested dimension out of bound");
1721LogicalResult GetStorageSpecifierOp::verify() {
1723 getSpecifier(), getOperation());
1726template <
typename SpecifierOp>
1728 return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1731OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1732 const StorageSpecifierKind kind = getSpecifierKind();
1733 const auto lvl = getLevel();
1735 if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1736 return op.getValue();
1740LogicalResult SetStorageSpecifierOp::verify() {
1742 getSpecifier(), getOperation());
1747 const char *regionName,
1750 unsigned expectedNum = inputTypes.size();
1751 if (numArgs != expectedNum)
1752 return op->emitError() << regionName <<
" region must have exactly "
1753 << expectedNum <<
" arguments";
1755 for (
unsigned i = 0; i < numArgs; i++) {
1757 if (typ != inputTypes[i])
1758 return op->emitError() << regionName <<
" region argument " << (i + 1)
1759 <<
" type mismatch";
1763 return op->emitError() << regionName
1764 <<
" region must end with a terminator";
1767 YieldOp yield = dyn_cast<YieldOp>(term);
1769 return op->emitError() << regionName
1770 <<
" region must end with sparse_tensor.yield";
1771 if (!yield.hasSingleResult() ||
1772 yield.getSingleResult().getType() != outputType)
1773 return op->emitError() << regionName <<
" region yield type mismatch";
1778LogicalResult BinaryOp::verify() {
1779 NamedAttrList attrs = (*this)->getAttrs();
1780 Type leftType = getX().getType();
1781 Type rightType = getY().getType();
1782 Type outputType = getOutput().getType();
1783 Region &overlap = getOverlapRegion();
1784 Region &left = getLeftRegion();
1785 Region &right = getRightRegion();
1789 if (!overlap.
empty()) {
1791 TypeRange{leftType, rightType}, outputType)))
1794 if (!left.
empty()) {
1798 }
else if (getLeftIdentity()) {
1799 if (leftType != outputType)
1800 return emitError(
"left=identity requires first argument to have the same "
1801 "type as the output");
1803 if (!right.
empty()) {
1807 }
else if (getRightIdentity()) {
1808 if (rightType != outputType)
1809 return emitError(
"right=identity requires second argument to have the "
1810 "same type as the output");
1815LogicalResult UnaryOp::verify() {
1816 Type inputType = getX().getType();
1817 Type outputType = getOutput().getType();
1821 Region &present = getPresentRegion();
1822 if (!present.
empty()) {
1827 Region &absent = getAbsentRegion();
1828 if (!absent.
empty()) {
1834 Block *parent = getOperation()->getBlock();
1836 cast<YieldOp>(absentBlock->
getTerminator()).getSingleResult();
1837 if (
auto arg = dyn_cast<BlockArgument>(absentVal)) {
1838 if (arg.getOwner() == parent)
1839 return emitError(
"absent region cannot yield linalg argument");
1841 if (!isa<arith::ConstantOp>(def) &&
1842 (def->getBlock() == absentBlock || def->getBlock() == parent))
1843 return emitError(
"absent region cannot yield locally computed value");
1849bool ConcatenateOp::needsExtraSort() {
1854 bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1861 bool directLowerable =
1862 allSameOrdered && getDimension() == 0 && dstStt.
isIdentity();
1863 return !directLowerable;
1866LogicalResult ConcatenateOp::verify() {
1868 const Dimension concatDim = getDimension();
1869 const Dimension dimRank = dstTp.getDimRank();
1871 if (getInputs().size() <= 1)
1872 return emitError(
"Need at least two tensors to concatenate.");
1874 if (concatDim >= dimRank)
1876 "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1877 concatDim, dimRank));
1879 for (
const auto &it : llvm::enumerate(getInputs())) {
1880 const auto i = it.index();
1882 if (srcTp.hasDynamicDimShape())
1883 return emitError(llvm::formatv(
"Input tensor ${0} has dynamic shape", i));
1884 const Dimension srcDimRank = srcTp.getDimRank();
1885 if (srcDimRank != dimRank)
1887 llvm::formatv(
"Input tensor ${0} has a different rank (rank={1}) "
1888 "from the output tensor (rank={2}).",
1889 i, srcDimRank, dimRank));
1892 for (
Dimension d = 0; d < dimRank; d++) {
1893 const Size dstSh = dstTp.getDimShape()[d];
1894 if (d == concatDim) {
1895 if (ShapedType::isStatic(dstSh)) {
1900 for (
const auto src : getInputs())
1906 "The concatenation dimension of the output tensor should be the "
1907 "sum of all the concatenation dimensions of the input tensors.");
1911 for (
const auto src : getInputs()) {
1913 if (ShapedType::isStatic(prev) && sh != prev)
1914 return emitError(
"All dimensions (expect for the concatenating one) "
1915 "should be equal.");
1924void PushBackOp::build(OpBuilder &builder, OperationState &
result,
1925 Value curSize, Value inBuffer, Value value) {
1926 build(builder,
result, curSize, inBuffer, value, Value());
1929LogicalResult PushBackOp::verify() {
1930 if (Value n =
getN()) {
1932 if (nValue && nValue.value() < 1)
1938LogicalResult CompressOp::verify() {
1940 if (stt.
getLvlRank() != 1 +
static_cast<Level>(getLvlCoords().size()))
1941 return emitOpError(
"incorrect number of coordinates");
1945void ForeachOp::build(
1946 OpBuilder &builder, OperationState &
result, Value tensor,
1950 build(builder,
result, initArgs.
getTypes(), tensor, initArgs, order);
1958 SmallVector<Type> blockArgTypes(dimRank, builder.
getIndexType());
1962 blockArgTypes.append(initArgs.
getTypes().begin(), initArgs.
getTypes().end());
1964 SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.
getLoc());
1966 OpBuilder::InsertionGuard guard(builder);
1967 auto ®ion = *
result.regions.front();
1969 builder.
createBlock(®ion, region.end(), blockArgTypes, blockArgLocs);
1970 bodyBuilder(builder,
result.location,
1976LogicalResult ForeachOp::verify() {
1978 const Dimension dimRank = t.getDimRank();
1979 const auto args = getBody()->getArguments();
1981 if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1982 return emitError(
"Level traverse order does not match tensor's level rank");
1984 if (dimRank + 1 + getInitArgs().size() != args.size())
1985 return emitError(
"Unmatched number of arguments in the block");
1987 if (getNumResults() != getInitArgs().size())
1988 return emitError(
"Mismatch in number of init arguments and results");
1990 if (getResultTypes() != getInitArgs().getTypes())
1991 return emitError(
"Mismatch in types of init arguments and results");
1994 auto yield = cast<YieldOp>(getBody()->getTerminator());
1995 if (yield.getNumOperands() != getNumResults() ||
1996 yield.getOperands().getTypes() != getResultTypes())
1997 return emitError(
"Mismatch in types of yield values and results");
1999 const auto iTp = IndexType::get(
getContext());
2003 llvm::formatv(
"Expecting Index type for argument at index {0}", d));
2005 const auto elemTp = t.getElementType();
2006 const auto valueTp = args[dimRank].getType();
2007 if (elemTp != valueTp)
2009 llvm::formatv(
"Unmatched element type between input tensor and "
2010 "block argument, expected:{0}, got: {1}",
2015OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
2018 return getInputCoo();
2023LogicalResult ReorderCOOOp::verify() {
2028 return emitError(
"Expected COO sparse tensors only");
2031 return emitError(
"Unmatched dim2lvl map between input and result COO");
2036 return emitError(
"Unmatched storage format between input and result COO");
2041LogicalResult ReduceOp::verify() {
2042 Type inputType = getX().getType();
2043 Region &formula = getRegion();
2045 TypeRange{inputType, inputType}, inputType);
2048LogicalResult SelectOp::verify() {
2050 Type inputType = getX().getType();
2051 Type boolType =
b.getI1Type();
2052 Region &formula = getRegion();
2057LogicalResult SortOp::verify() {
2058 AffineMap xPerm = getPermMap();
2061 return emitError(llvm::formatv(
"Expected rank(perm_map) > 1, got {0}", nx));
2065 llvm::formatv(
"Expected a permutation map, got {0}", xPerm));
2074 const auto checkDim = [&](Value v,
Size minSize,
2075 const char *message) -> LogicalResult {
2077 if (ShapedType::isStatic(sh) && sh < minSize)
2079 llvm::formatv(
"{0} got {1} < {2}", message, sh, minSize));
2082 uint64_t n = cn.value();
2084 if (
auto nyAttr = getNyAttr())
2085 ny = nyAttr.getInt();
2086 if (
failed(checkDim(getXy(), n * (nx + ny),
2087 "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
2089 for (Value opnd : getYs())
2090 if (
failed(checkDim(opnd, n,
"Expected dimension(y) >= n")))
2100IterSpaceType IteratorType::getIterSpaceType()
const {
2101 return IterSpaceType::get(
getContext(), getEncoding(), getLoLvl(),
2105IteratorType IterSpaceType::getIteratorType()
const {
2106 return IteratorType::get(
getContext(), getEncoding(), getLoLvl(), getHiLvl());
2125 "expect larger level upper bound than lower bound");
2133 IntegerAttr &lvlHiAttr) {
2150 p << lo <<
" to " << hi;
2156 IntegerAttr lvlHi) {
2157 unsigned lo = lvlLo.getValue().getZExtValue();
2158 unsigned hi = lvlHi.getValue().getZExtValue();
2169 unsigned maxCnt = std::numeric_limits<unsigned>::max(),
2172 ParseResult crdList =
2177 definedSet.
set(cnt);
2185 "parsed more value than expected.");
2187 if (failed(crdList)) {
2190 "expecting SSA value or \"_\" for level coordinates");
2192 assert(definedArgs.size() == definedSet.
count());
2199 if (definedSet.
empty())
2202 for (
unsigned i = 0; i < size; i++) {
2203 if (definedSet[i]) {
2204 p << blocksArgs.front();
2205 blocksArgs = blocksArgs.drop_front();
2212 assert(blocksArgs.empty());
2225 for (
auto &coord : coords)
2246 if (iterators.size() != spaces.size())
2249 "mismatch in number of sparse iterators and sparse spaces");
2254 size_t numCrds = coords.size();
2262 blockArgs.append(coords);
2268 if (iterSpaceTps.size() != spaces.size())
2270 "mismatch in number of iteration space operands "
2271 "and iteration space types");
2273 for (
auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
2274 IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
2277 "expected sparse_tensor.iter_space type for "
2278 "iteration space operands");
2279 it.type = spaceTp.getIteratorType();
2294 if (args.size() != initArgs.size() || args.size() != state.
types.size()) {
2297 "mismatch in number of iteration arguments and return values");
2300 for (
auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.
types)) {
2322 size_t numCrds = coords.size();
2330 blockArgs.append(coords);
2338 if (iterSpaceTps.size() != spaces.size())
2340 "mismatch in number of iteration space operands "
2341 "and iteration space types");
2356 if (args.size() != initArgs.size() || args.size() != state.
types.size()) {
2359 "mismatch in number of iteration arguments and return values");
2362 for (
auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.
types)) {
2371LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2372 MLIRContext *ctx, std::optional<Location> loc,
ValueRange ops,
2373 DictionaryAttr attr, PropertyRef prop, RegionRange region,
2374 SmallVectorImpl<mlir::Type> &ret) {
2376 ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2378 ret.push_back(IterSpaceType::get(ctx, stt.
getEncoding(), adaptor.getLoLvl(),
2379 adaptor.getHiLvl()));
2383LogicalResult ExtractIterSpaceOp::verify() {
2384 if (getLoLvl() >= getHiLvl())
2385 return emitOpError(
"expected smaller level low than level high");
2388 if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2390 "parent iterator should be specified iff level lower bound equals 0");
2394 IterSpaceType spaceTp = getExtractedSpace().getType();
2395 if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2397 "mismatch in parent iterator encoding and iteration space encoding.");
2399 if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2400 return emitOpError(
"parent iterator should be used to extract an "
2401 "iteration space from a consecutive level.");
2407LogicalResult ExtractValOp::verify() {
2409 auto itTp = getIterator().getType();
2412 return emitOpError(
"mismatch in tensor encoding and iterator encoding.");
2415 return emitOpError(
"must use last-level iterator to extract values. ");
2426 llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
2427 for (
unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
2428 if (
auto crd = iterateOp.getLvlCrd(i)) {
2429 if (crd->getUsers().empty())
2430 toRemove.set(crd->getArgNumber());
2437 if (toRemove.none())
2441 iterateOp.setCrdUsedLvls(newUsedLvls);
2442 iterateOp.getBody()->eraseArguments(toRemove);
2448void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
2449 mlir::MLIRContext *context) {
2450 results.
add<RemoveUnusedLvlCrds>(context);
2453void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2455 unsigned rank = llvm::cast<IterSpaceType>(iterSpace.
getType()).getSpaceDim();
2458 return build(builder, odsState, iterSpace, initArgs, set);
2461void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2464 OpBuilder::InsertionGuard guard(builder);
2470 Region *bodyRegion = odsState.
addRegion();
2475 for (Value v : initArgs)
2479 for (
unsigned i = 0, e = crdUsedLvls.
count(); i < e; i++)
2484 llvm::cast<IterSpaceType>(iterSpace.
getType()).getIteratorType(),
2488ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &
result) {
2489 OpAsmParser::Argument iterator;
2490 OpAsmParser::UnresolvedOperand iterSpace;
2492 SmallVector<OpAsmParser::Argument> iters, iterArgs;
2495 if (iters.size() != 1)
2497 "expected only one iterator/iteration space");
2499 iterArgs.append(iters);
2500 Region *body =
result.addRegion();
2520 StringRef prefix =
"") {
2521 assert(blocksArgs.size() == initializers.size() &&
2522 "expected same length of arguments and initializers");
2523 if (initializers.empty())
2527 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
2528 p << std::get<0>(it) <<
" = " << std::get<1>(it);
2533template <
typename SparseLoopOp>
2535 if (op.getInitArgs().size() != op.getNumResults()) {
2536 return op.emitOpError(
2537 "mismatch in number of loop-carried values and defined values");
2539 if (op.getCrdUsedLvls().max() > op.getSpaceDim())
2540 return op.emitOpError(
"required out-of-bound coordinates");
2548void IterateOp::print(OpAsmPrinter &p) {
2549 p <<
" " << getIterator() <<
" in " << getIterSpace();
2550 if (!getCrdUsedLvls().empty()) {
2557 p <<
" : " << getIterSpace().getType() <<
" ";
2558 if (!getInitArgs().empty())
2563 !getInitArgs().empty());
2566LogicalResult IterateOp::verifyRegions() {
2567 if (getIterator().
getType() != getIterSpace().
getType().getIteratorType())
2568 return emitOpError(
"mismatch in iterator and iteration space type");
2569 if (getNumRegionIterArgs() != getNumResults())
2571 "mismatch in number of basic block args and defined values");
2573 auto initArgs = getInitArgs();
2574 auto iterArgs = getRegionIterArgs();
2575 auto yieldVals = getYieldedValues();
2576 auto opResults = getResults();
2577 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2578 opResults.size()})) {
2579 return emitOpError() <<
"number mismatch between iter args and results.";
2582 for (
auto [i, init, iter, yield, ret] :
2583 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2584 if (init.getType() != ret.getType())
2585 return emitOpError() <<
"types mismatch between " << i
2586 <<
"th iter operand and defined value";
2587 if (iter.getType() != ret.getType())
2588 return emitOpError() <<
"types mismatch between " << i
2589 <<
"th iter region arg and defined value";
2590 if (yield.getType() != ret.getType())
2591 return emitOpError() <<
"types mismatch between " << i
2592 <<
"th yield value and defined value";
2599SmallVector<Region *> IterateOp::getLoopRegions() {
return {&getRegion()}; }
2601MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
2602 return getInitArgsMutable();
2606 return getRegion().getArguments().take_front(getNumRegionIterArgs());
2609std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
2610 return cast<sparse_tensor::YieldOp>(
2611 getRegion().getBlocks().front().getTerminator())
2612 .getResultsMutable();
2615std::optional<ResultRange> IterateOp::getLoopResults() {
return getResults(); }
2617OperandRange IterateOp::getEntrySuccessorOperands(RegionSuccessor successor) {
2618 return getInitArgs();
2621void IterateOp::getSuccessorRegions(RegionBranchPoint point,
2622 SmallVectorImpl<RegionSuccessor> ®ions) {
2625 regions.push_back(RegionSuccessor(&getRegion()));
2630ValueRange IterateOp::getSuccessorInputs(RegionSuccessor successor) {
2635void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
2637 unsigned numCases) {
2639 cast<IterSpaceType>(iterSpaces.front().
getType()).getSpaceDim();
2646 SmallVector<int64_t> caseBits(numCases, 0);
2648 return CoIterateOp::build(builder, odsState, initArgs.
getTypes(), iterSpaces,
2649 initArgs, set, cases,
2653ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &
result) {
2655 SmallVector<Value> spaces;
2658 SmallVector<OpAsmParser::Argument> blockArgs;
2662 result.addAttribute(
"operandSegmentSizes",
2664 {static_cast<int32_t>(spaces.size()),
2665 static_cast<int32_t>(result.types.size())}));
2667 SmallVector<Attribute> cases;
2671 SmallVector<OpAsmParser::Argument> definedIts;
2678 for (
auto [i, definedIdx] : llvm::enumerate(definedItSet.
bits())) {
2680 auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].
getType());
2681 definedIts[i].type = spaceTp.getIteratorType();
2683 definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
2684 Region *body =
result.addRegion();
2688 CoIterateOp::ensureTerminator(*body, parser.
getBuilder(),
result.location);
2700void CoIterateOp::print(OpAsmPrinter &p) {
2702 llvm::interleaveComma(getIterSpaces(), p, [&](
auto s) { p << s; });
2705 if (!getCrdUsedLvls().empty()) {
2713 p <<
" : (" << getIterSpaces().getTypes() <<
")";
2714 if (!getInitArgs().empty())
2715 p.printArrowTypeList(getInitArgs().getTypes());
2717 for (
unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
2721 getRegionDefinedSpace(idx));
2723 p.printRegion(getRegion(idx),
false,
2724 !getInitArgs().empty());
2728ValueRange CoIterateOp::getYieldedValues(
unsigned regionIdx) {
2729 return cast<sparse_tensor::YieldOp>(
2730 getRegion(regionIdx).getBlocks().front().getTerminator())
2734LogicalResult CoIterateOp::verifyRegions() {
2735 for (
unsigned r = 0, e = getNumRegions(); r < e; r++) {
2736 if (getNumRegionIterArgs() != getNumResults())
2738 "mismatch in number of basic block args and defined values");
2740 auto initArgs = getInitArgs();
2741 auto iterArgs = getRegionIterArgs(r);
2742 auto yieldVals = getYieldedValues(r);
2743 auto opResults = getResults();
2744 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2745 opResults.size()})) {
2747 <<
"number mismatch between iter args and results on " << r
2751 for (
auto [i, init, iter, yield, ret] :
2752 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2753 if (init.getType() != ret.getType())
2755 <<
"types mismatch between " << i
2756 <<
"th iter operand and defined value on " << r <<
"th region";
2757 if (iter.getType() != ret.getType())
2758 return emitOpError() <<
"types mismatch between " << i
2759 <<
"th iter region arg and defined value on " << r
2761 if (yield.getType() != ret.getType())
2763 <<
"types mismatch between " << i
2764 <<
"th yield value and defined value on " << r <<
"th region";
2768 auto cases = getRegionDefinedSpaces();
2769 llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
2770 if (set.size() != getNumRegions())
2776SmallVector<Region *> CoIterateOp::getSubCasesOf(
unsigned regionIdx) {
2777 SmallVector<Region *> ret;
2778 I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
2779 for (Region &r : getCaseRegions())
2780 if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
2792Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2793 Attribute value, Type type,
2795 if (
auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2800void SparseTensorDialect::initialize() {
2802#define GET_ATTRDEF_LIST
2803#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
2806#define GET_TYPEDEF_LIST
2807#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
2811#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2813 declarePromisedInterfaces<
2814 bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
2815 NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
2816 ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
2819#define GET_OP_CLASSES
2820#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2822#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static bool isPermutation(const std::vector< PermutationTy > &permutation)
static Type getElementType(Type type)
Determine the element type of type.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static bool isUnique(It begin, It end)
static LogicalResult verifyNumBlockArgs(T *op, Region ®ion, const char *regionName, TypeRange inputTypes, Type outputType)
static ParseResult parseOptionalStaticSlice(int64_t &result, AsmParser &parser)
static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc)
We normalized sparse tensor encoding attribute by always using ordered/unique LT such that "compresse...
static ParseResult parseUsedCoordList(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &coords)
static LogicalResult isMatchingWidth(Value mem, unsigned width)
static constexpr bool acceptBitWidth(unsigned bitWidth)
static mlir::ParseResult parseLevelRange(mlir::AsmParser &, mlir::sparse_tensor::Level &, mlir::sparse_tensor::Level &)
Parses a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static LogicalResult lvlIsInBounds(Level lvl, Value tensor)
static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size, Block::BlockArgListType blocksArgs, I64BitSet definedSet)
static constexpr FieldIndex kDataFieldStartingIdx
static constexpr Level kInvalidLevel
static LogicalResult verifySparseLoopOp(SparseLoopOp op)
static constexpr Level kInvalidFieldIndex
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, mlir::sparse_tensor::Level)
Prints a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind)
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op)
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, PropertyRef prop, RegionRange region, SmallVectorImpl< mlir::Type > &ret)
static ParseResult parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &iterators, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static SmallVector< Size > getSparseFieldShape(const SparseTensorEncodingAttr enc, std::optional< ArrayRef< int64_t > > dimShape)
static ParseResult parseOptionalDefinedList(OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, SmallVectorImpl< OpAsmParser::Argument > &definedArgs, unsigned maxCnt=std::numeric_limits< unsigned >::max(), OpAsmParser::Delimiter delimiter=OpAsmParser::Delimiter::Paren)
Parses a list of optional defined list in the form of "(%val0, _, %val1, ...)", where _ is used to an...
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, SparseTensorType stt, RankedTensorType valTp, TypeRange lvlTps)
static ParseResult parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< Value > &spacesVals, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static LogicalResult verifySparsifierGetterSetter(StorageSpecifierKind mdKind, std::optional< Level > lvl, TypedValue< StorageSpecifierType > md, Operation *op)
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
void print(raw_ostream &os) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseLBrace()=0
Parse a { token.
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseQuestion()=0
Parse a '?' token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printArrowTypeList(TypeRange &&types)
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
BlockArgListType getArguments()
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
This class provides an abstraction over the different types of ranges over Regions.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (of any width).
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
A simple wrapper to encode a bitset of (at most 64) levels, currently used by sparse_tensor....
iterator_range< const_set_bits_iterator > bits() const
I64BitSet & set(unsigned i)
A wrapper around RankedTensorType, which has three goals:
bool isSingletonLvl(Level l) const
SmallVector< Size > getBatchLvlShape() const
Returns the batched level-shape.
MLIRContext * getContext() const
Type getElementType() const
bool isLooseCompressedLvl(Level l) const
unsigned getCrdWidth() const
Returns the coordinate-overhead bitwidth, defaulting to zero.
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
bool isAllOrdered() const
Returns true for tensors where every level is ordered.
bool isCOOType(Level startLvl=0, bool isUnique=true) const
Returns true iff this sparse tensor type has a trailing COO region starting at the given level.
Dimension getDimRank() const
Returns the dimension-rank.
AffineMap getLvlToDim() const
Returns the lvlToDim mapping (or the null-map for the identity).
Attribute getImplicitVal() const
Returns the implicit value, defaulting to null Attribute for 0.
bool isAllDense() const
Returns true for tensors where every level is dense.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
bool hasSameDimToLvl(const SparseTensorType &other) const
Returns true iff the two types have the same mapping.
ArrayRef< Size > getDimShape() const
Returns the dimension-shape.
SmallVector< Size > getLvlShape() const
Returns the level-shape.
bool isCompressedLvl(Level l) const
bool hasStaticDimShape() const
Returns true if no dimension has dynamic size.
Level getLvlRank() const
Returns the level-rank.
ArrayRef< LevelType > getLvlTypes() const
unsigned getPosWidth() const
Returns the position-overhead bitwidth, defaulting to zero.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
LevelType getLvlType(Level l) const
AffineMap getDimToLvl() const
Returns the dimToLvl mapping (or the null-map for the identity).
Attribute getExplicitVal() const
Returns the explicit value, defaulting to null Attribute for unset.
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
bool isUniqueLvl(Level l) const
Provides methods to access fields of a sparse tensor with the given encoding.
unsigned getNumDataFields() const
Gets the total number of data fields (coordinate arrays, position arrays, and a value array) for the ...
unsigned getNumFields() const
Gets the total number of fields for the given sparse tensor encoding.
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with t...
std::pair< FieldIndex, unsigned > getFieldIndexAndStride(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Parses the Sparse Tensor Encoding Attribute (STEA).
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
bool isUniqueLT(LevelType lt)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool isWithCrdLT(LevelType lt)
std::optional< LevelType > buildLevelType(LevelFormat lf, const std::vector< LevelPropNonDefault > &properties, uint64_t n=0, uint64_t m=0)
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
bool isWithPosLT(LevelType lt)
bool isOrderedLT(LevelType lt)
std::string toMLIRString(LevelType lt)
Dimension toDim(SparseTensorEncodingAttr enc, Level l)
Convenience method to translate the given level to the corresponding dimension.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
bool isSingletonLT(LevelType lt)
static llvm::hash_code hash_value(LevelType lt)
uint64_t getN(LevelType lt)
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context)
Given the dimToLvl map, infers the lvlToDim map, or returns empty Affine map when inference fails.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
bool isBlockSparsity(AffineMap dimToLvl)
Given the dimToLvl map, returns if it's block sparsity.
bool isDenseLT(LevelType lt)
uint64_t getM(LevelType lt)
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SparseTensorFieldKind
===----------------------------------------------------------------------===// The sparse tensor storag...
bool isBatchLT(LevelType lt)
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context)
Returns the lvlToDim map for the given dimToLvl map specific to the block sparse cases.
bool isNOutOfMLT(LevelType lt)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
llvm::function_ref< Fn > function_ref
LogicalResult matchAndRewrite(IterateOp iterateOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) the properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
A simple structure that encodes a range of levels in the sparse tensors that forms a COO segment.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.
LevelType stripStorageIrrelevantProperties() const