26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/FormatVariadic.h"
29#define GET_ATTRDEF_CLASSES
30#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
31#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
41#define GET_TYPEDEF_CLASSES
42#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
51 return llvm::hash_value(
static_cast<uint64_t
>(lt));
79 if (dimShape.has_value()) {
83 enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
84 memrefShape.assign(lvlShape.begin(),
85 lvlShape.begin() + enc.getBatchLvlRank());
88 memrefShape.push_back(ShapedType::kDynamic);
104 const auto lvlTypes = enc.getLvlTypes();
105 const Level lvlRank = enc.getLvlRank();
111 for (
Level l = 0; l < lvlRank; ) {
112 const auto lt = lvlTypes[l];
121 if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
122 if (!cooSegsRef.front().isSoA) {
125 l = cooSegsRef.front().lvlRange.second;
131 cooSegsRef = cooSegsRef.drop_front();
159 const Type posMemType = MemRefType::get(memrefShape, stt.
getPosType());
161 const Type crdMemType = MemRefType::get(memrefShape, stt.
getCrdType());
171 return callback(specType, fieldIdx, fieldKind, lvl, lt);
173 return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
175 return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
177 return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
179 llvm_unreachable(
"unrecognized field kind");
184 unsigned numFields = 0;
194 unsigned numFields = 0;
206std::pair<FieldIndex, unsigned>
208 std::optional<Level> lvl)
const {
212 assert(lvl.has_value());
213 const Level cooStart = enc.getAoSCOOStart();
214 const Level lvlRank = enc.getLvlRank();
215 if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
217 stride = lvlRank - cooStart;
223 if ((lvl && fLvl == lvl.value() && kind == fKind) ||
232 return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
239std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(
int64_t v) {
240 return isDynamic(v) ? std::nullopt
241 : std::make_optional(
static_cast<uint64_t
>(v));
244std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset()
const {
245 return getStatic(getOffset());
248std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride()
const {
249 return getStatic(getStride());
252std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize()
const {
253 return getStatic(getSize());
256bool SparseTensorDimSliceAttr::isCompletelyDynamic()
const {
257 return isDynamic(getOffset()) && isDynamic(getStride()) &&
258 isDynamic(getSize());
261std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
262 return isDynamic(v) ?
"?" : std::to_string(v);
265void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os)
const {
266 assert(getImpl() &&
"Uninitialized SparseTensorDimSliceAttr");
268 os << getStaticString(getOffset());
270 os << getStaticString(getSize());
272 os << getStaticString(getStride());
276void SparseTensorDimSliceAttr::print(AsmPrinter &printer)
const {
283 if (parseResult.has_value()) {
284 if (parseResult.value().succeeded() &&
result < 0) {
287 "expect positive value or ? for slice offset/size/stride");
290 return parseResult.value();
294 result = SparseTensorDimSliceAttr::kDynamic;
298Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
299 int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
311 offset, size, stride);
316 int64_t offset, int64_t size, int64_t stride) {
317 if (!isDynamic(offset) && offset < 0)
318 return emitError() <<
"expect non-negative value or ? for slice offset";
319 if (!isDynamic(size) && size <= 0)
320 return emitError() <<
"expect positive value or ? for slice size";
321 if (!isDynamic(stride) && stride <= 0)
322 return emitError() <<
"expect positive value or ? for slice stride";
326SparseTensorEncodingAttr
327SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl)
const {
328 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
329 return SparseTensorEncodingAttr::get(
330 getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
331 getCrdWidth(), getExplicitVal(), getImplicitVal());
334SparseTensorEncodingAttr
335SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc)
const {
336 return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
339SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl()
const {
340 return withDimToLvl(AffineMap());
343SparseTensorEncodingAttr
344SparseTensorEncodingAttr::withBitWidths(
unsigned posWidth,
345 unsigned crdWidth)
const {
346 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
347 return SparseTensorEncodingAttr::get(
348 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
349 crdWidth, getExplicitVal(), getImplicitVal());
352SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths()
const {
353 return withBitWidths(0, 0);
356SparseTensorEncodingAttr
357SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal)
const {
358 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
359 return SparseTensorEncodingAttr::get(
360 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
361 getCrdWidth(), explicitVal, getImplicitVal());
364SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal()
const {
365 return withExplicitVal(Attribute());
368SparseTensorEncodingAttr
369SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal)
const {
370 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
371 return SparseTensorEncodingAttr::get(
372 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
373 getCrdWidth(), getExplicitVal(), implicitVal);
376SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal()
const {
377 return withImplicitVal(Attribute());
380SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
381 ArrayRef<SparseTensorDimSliceAttr> dimSlices)
const {
382 return SparseTensorEncodingAttr::get(
383 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
384 getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
387SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices()
const {
388 return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
391uint64_t SparseTensorEncodingAttr::getBatchLvlRank()
const {
392 ArrayRef<LevelType> lvlTypes = getLvlTypes();
393 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(),
isBatchLT);
394 return std::distance(lastBatch, lvlTypes.rend());
397bool SparseTensorEncodingAttr::isAllDense()
const {
398 return !getImpl() || llvm::all_of(getLvlTypes(),
isDenseLT);
401bool SparseTensorEncodingAttr::isAllOrdered()
const {
402 return !getImpl() || llvm::all_of(getLvlTypes(),
isOrderedLT);
405Type SparseTensorEncodingAttr::getCrdElemType()
const {
409 return IntegerType::get(
getContext(), getCrdWidth());
413Type SparseTensorEncodingAttr::getPosElemType()
const {
417 return IntegerType::get(
getContext(), getPosWidth());
421MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
422 std::optional<ArrayRef<int64_t>> dimShape)
const {
424 return MemRefType::get(shape, getCrdElemType());
427MemRefType SparseTensorEncodingAttr::getPosMemRefType(
428 std::optional<ArrayRef<int64_t>> dimShape)
const {
430 return MemRefType::get(shape, getPosElemType());
433bool SparseTensorEncodingAttr::isIdentity()
const {
434 return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
437bool SparseTensorEncodingAttr::isPermutation()
const {
438 return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
441Dimension SparseTensorEncodingAttr::getDimRank()
const {
442 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
443 const auto dimToLvl = getDimToLvl();
444 return dimToLvl ? dimToLvl.
getNumDims() : getLvlRank();
447Level SparseTensorEncodingAttr::getLvlRank()
const {
448 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
449 return getLvlTypes().size();
455 assert(l < getLvlRank() &&
"Level is out of bounds");
456 return getLvlTypes()[l];
459bool SparseTensorEncodingAttr::isSlice()
const {
460 assert(getImpl() &&
"Uninitialized SparseTensorEncodingAttr");
461 return !getDimSlices().empty();
464SparseTensorDimSliceAttr
465SparseTensorEncodingAttr::getDimSlice(
Dimension dim)
const {
466 assert(isSlice() &&
"Is not a slice");
467 const auto dimSlices = getDimSlices();
468 assert(dim < dimSlices.size() &&
"Dimension is out of bounds");
469 return dimSlices[dim];
472std::optional<uint64_t>
473SparseTensorEncodingAttr::getStaticDimSliceOffset(
Dimension dim)
const {
474 return getDimSlice(dim).getStaticOffset();
477std::optional<uint64_t>
478SparseTensorEncodingAttr::getStaticDimSliceStride(
Dimension dim)
const {
479 return getDimSlice(dim).getStaticStride();
482std::optional<uint64_t>
483SparseTensorEncodingAttr::getStaticLvlSliceOffset(
Level lvl)
const {
484 return getStaticDimSliceOffset(
toDim(*
this, lvl));
487std::optional<uint64_t>
488SparseTensorEncodingAttr::getStaticLvlSliceStride(
Level lvl)
const {
489 return getStaticDimSliceStride(
toDim(*
this, lvl));
493SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
494 CrdTransDirectionKind dir)
const {
496 return SmallVector<int64_t>(srcShape);
498 SmallVector<int64_t> ret;
500 dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
504 for (
unsigned r = 0; r < rank; r++) {
505 unsigned trans = dir == CrdTransDirectionKind::dim2lvl ?
toDim(*
this, r)
507 ret.push_back(srcShape[trans]);
514 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
521 ret.resize(rank, ShapedType::kDynamic);
525 SmallVector<AffineExpr> dimRep;
526 dimRep.reserve(srcShape.size());
527 for (int64_t sz : srcShape) {
528 if (ShapedType::isStatic(sz)) {
540 unsigned numSymbols = getDimToLvl().getNumSymbols();
542 for (AffineExpr exp : transMap.
getResults()) {
545 srcShape.size(), numSymbols);
547 if (
auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
548 ret.push_back(c.getValue() + 1);
550 if (
auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
554 if (
auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
555 ret.push_back(bound.getValue());
559 ret.push_back(ShapedType::kDynamic);
562 assert(ret.size() == rank);
567SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
569 CrdTransDirectionKind dir)
const {
573 SmallVector<Type> retType(
574 dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
577 CrdTranslateOp::create(builder, loc, retType, crds, dir, *
this);
578 return transOp.getOutCrds();
581Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
589 SmallVector<LevelType> lvlTypes;
590 SmallVector<SparseTensorDimSliceAttr> dimSlices;
591 AffineMap dimToLvl = {};
592 AffineMap lvlToDim = {};
593 unsigned posWidth = 0;
594 unsigned crdWidth = 0;
595 Attribute explicitVal;
596 Attribute implicitVal;
598 SmallVector<StringRef, 5> keys = {
"map",
"posWidth",
"crdWidth",
599 "explicitVal",
"implicitVal"};
602 auto *it = find(keys, attrName);
603 if (it == keys.end()) {
607 unsigned keyWordIndex = it - keys.begin();
612 switch (keyWordIndex) {
615 auto res = cParser.parseDimLvlMap();
618 const auto &dlm = *res;
620 const Level lvlRank = dlm.getLvlRank();
621 for (
Level lvl = 0; lvl < lvlRank; lvl++)
622 lvlTypes.push_back(dlm.getLvlType(lvl));
624 const Dimension dimRank = dlm.getDimRank();
625 for (
Dimension dim = 0; dim < dimRank; dim++)
626 dimSlices.push_back(dlm.getDimSlice(dim));
630 const auto isDefined = [](SparseTensorDimSliceAttr slice) {
631 return static_cast<bool>(slice.getImpl());
633 if (llvm::any_of(dimSlices, isDefined)) {
634 const auto defaultSlice =
635 SparseTensorDimSliceAttr::get(parser.
getContext());
636 for (
Dimension dim = 0; dim < dimRank; dim++)
637 if (!isDefined(dimSlices[dim]))
638 dimSlices[dim] = defaultSlice;
643 dimToLvl = dlm.getDimToLvlMap(parser.
getContext());
644 lvlToDim = dlm.getLvlToDimMap(parser.
getContext());
651 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
654 "expected an integral position bitwidth");
657 posWidth = intAttr.getInt();
664 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
667 "expected an integral index bitwidth");
670 crdWidth = intAttr.getInt();
677 if (
auto result = llvm::dyn_cast<FloatAttr>(attr)) {
679 }
else if (
auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
681 }
else if (
auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
685 "expected a numeric value for explicitVal");
694 if (
auto result = llvm::dyn_cast<FloatAttr>(attr)) {
696 }
else if (
auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
698 }
else if (
auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
702 "expected a numeric value for implicitVal");
720 if (!lvlToDim || lvlToDim.
isEmpty()) {
723 return parser.
getChecked<SparseTensorEncodingAttr>(
724 parser.
getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
725 explicitVal, implicitVal, dimSlices);
728void SparseTensorEncodingAttr::print(AsmPrinter &printer)
const {
729 auto map =
static_cast<AffineMap
>(getDimToLvl());
733 printer <<
"<{ map = ";
734 printSymbols(map, printer);
736 printDimensions(map, printer, getDimSlices());
738 printLevels(map, printer, getLvlTypes());
742 printer <<
", posWidth = " << getPosWidth();
744 printer <<
", crdWidth = " << getCrdWidth();
745 if (getExplicitVal()) {
746 printer <<
", explicitVal = " << getExplicitVal();
748 if (getImplicitVal())
749 printer <<
", implicitVal = " << getImplicitVal();
753void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
754 AsmPrinter &printer)
const {
758 for (
unsigned i = 0, n = map.
getNumSymbols() - 1; i < n; i++)
759 printer <<
's' << i <<
", ";
765void SparseTensorEncodingAttr::printDimensions(
766 AffineMap &map, AsmPrinter &printer,
767 ArrayRef<SparseTensorDimSliceAttr> dimSlices)
const {
768 if (!dimSlices.empty()) {
769 for (
unsigned i = 0, n = map.
getNumDims() - 1; i < n; i++)
770 printer <<
'd' << i <<
" : " << dimSlices[i] <<
", ";
772 printer <<
'd' << map.
getNumDims() - 1 <<
" : "
776 for (
unsigned i = 0, n = map.
getNumDims() - 1; i < n; i++)
777 printer <<
'd' << i <<
", ";
783void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
784 ArrayRef<LevelType> lvlTypes)
const {
785 for (
unsigned i = 0, n = map.
getNumResults() - 1; i < n; i++) {
796LogicalResult SparseTensorEncodingAttr::verify(
798 AffineMap dimToLvl, AffineMap lvlToDim,
unsigned posWidth,
799 unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
800 ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
802 return emitError() <<
"unexpected position bitwidth: " << posWidth;
804 return emitError() <<
"unexpected coordinate bitwidth: " << crdWidth;
808 while (it != lvlTypes.end()) {
809 if (it == lvlTypes.begin() ||
811 return emitError() <<
"expected compressed or loose_compressed level "
812 "before singleton level";
814 auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(),
isSingletonLT);
816 return emitError() <<
"expected all singleton lvlTypes "
817 "following a singleton level";
819 if (!std::all_of(it, curCOOEnd, [it](
LevelType i) {
823 return emitError() <<
"expected all singleton lvlTypes stored in the "
824 "same memory layout (SoA vs AoS).";
829 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(),
isBatchLT);
830 if (!std::all_of(lastBatch, lvlTypes.rend(),
isBatchLT))
831 return emitError() <<
"Batch lvlType can only be leading levels.";
834 auto soaLvls = llvm::make_filter_range(lvlTypes, [](
LevelType lt) {
837 if (llvm::any_of(soaLvls, [](
LevelType lt) {
840 return emitError() <<
"SoA is only applicable to singleton lvlTypes.";
847 for (
auto [i, lt] : llvm::drop_begin(llvm::enumerate(lvlTypes))) {
849 return emitError() <<
"dense level cannot follow a non-unique level";
853 if (
auto it = llvm::find_if(lvlTypes,
isNOutOfMLT);
854 it != std::end(lvlTypes)) {
855 if (it != lvlTypes.end() - 1)
856 return emitError() <<
"expected n_out_of_m to be the last level type";
857 if (!std::all_of(lvlTypes.begin(), it,
isDenseLT))
858 return emitError() <<
"expected all dense lvlTypes "
859 "before a n_out_of_m level";
863 <<
"expected 1xm block structure for n_out_of_m level";
866 unsigned coefficient = 0;
867 for (
const auto &elem : sizes) {
869 if (elem != coefficient && coefficient != 0) {
870 return emitError() <<
"expected only one blocked level "
871 "with the same coefficients";
876 if (coefficient !=
getM(*it)) {
877 return emitError() <<
"expected coeffiencts of Affine expressions "
878 "to be equal to m of n_out_of_m level";
887 const Level lvlRank = lvlTypes.size();
889 return emitError() <<
"expected a non-empty array for lvlTypes";
895 <<
"level-rank mismatch between dimToLvl and lvlTypes: "
900 return emitError() <<
"failed to infer lvlToDim from dimToLvl";
901 if (lvlToDim && (inferRes != lvlToDim))
902 return emitError() <<
"expected lvlToDim to be an inverse of dimToLvl";
903 if (dimRank > lvlRank)
904 return emitError() <<
"unexpected dimToLvl mapping from " << dimRank
905 <<
" to " << lvlRank;
907 if (!dimSlices.empty()) {
908 if (dimSlices.size() != dimRank)
910 <<
"dimension-rank mismatch between dimSlices and dimToLvl: "
911 << dimSlices.size() <<
" != " << dimRank;
914 if (dimRank != lvlRank)
916 <<
"dimSlices expected dimension-rank to match level-rank: "
917 << dimRank <<
" != " << lvlRank;
922LogicalResult SparseTensorEncodingAttr::verifyEncoding(
923 ArrayRef<Size> dimShape, Type elementType,
928 getPosWidth(), getCrdWidth(), getExplicitVal(),
929 getImplicitVal(), getDimSlices())))
934 const Dimension dimRank = dimShape.size();
936 return emitError() <<
"expected non-scalar sparse tensor";
937 if (getDimRank() != dimRank)
939 <<
"dimension-rank mismatch between encoding and tensor shape: "
940 << getDimRank() <<
" != " << dimRank;
941 if (
auto expVal = getExplicitVal()) {
942 Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
943 if (attrType != elementType) {
944 return emitError() <<
"explicit value type mismatch between encoding and "
945 <<
"tensor element type: " << attrType
946 <<
" != " << elementType;
949 if (
auto impVal = getImplicitVal()) {
950 Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
951 if (attrType != elementType) {
952 return emitError() <<
"implicit value type mismatch between encoding and "
953 <<
"tensor element type: " << attrType
954 <<
" != " << elementType;
957 auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
958 auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
959 auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
960 if ((impFVal && impFVal.getValue().isNonZero()) ||
961 (impIntVal && !impIntVal.getValue().isZero()) ||
962 (impComplexVal && (impComplexVal.getImag().isNonZero() ||
963 impComplexVal.getReal().isNonZero()))) {
964 return emitError() <<
"implicit value must be zero";
970Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart()
const {
971 SmallVector<COOSegment> coo = getCOOSegments();
972 assert(coo.size() == 1 || coo.empty());
973 if (!coo.empty() && coo.front().isAoS()) {
974 return coo.front().lvlRange.first;
979SmallVector<COOSegment>
980mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments()
const {
981 SmallVector<COOSegment> ret;
982 if (getLvlRank() <= 1)
985 ArrayRef<LevelType> lts = getLvlTypes();
987 while (l < getLvlRank()) {
990 auto cur = lts.begin() + l;
991 auto end = std::find_if(cur + 1, lts.end(), [](
LevelType lt) {
992 return !lt.isa<LevelFormat::Singleton>();
994 unsigned cooLen = std::distance(cur, end);
1000 ret.push_back(
COOSegment{std::make_pair(l, l + cooLen),
1021 for (
Level l = startLvl + 1; l < lvlRank; ++l)
1033 lvlTypes.reserve(lvlRank);
1040 std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
1045 auto enc = SparseTensorEncodingAttr::get(
1055SparseTensorEncodingAttr
1057 if (
auto ttp = llvm::dyn_cast<RankedTensorType>(type))
1058 return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
1059 if (
auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
1060 return mdtp.getEncoding();
1066 auto map =
static_cast<AffineMap>(dimToLvl);
1083 lvlExprs.reserve(numLvls);
1086 std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
1087 for (
unsigned i = 0, n = numLvls; i < n; i++) {
1089 if (
auto binOp = dyn_cast<AffineBinaryOpExpr>(
result)) {
1092 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1093 assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
1094 "expected only one floordiv for each dimension");
1099 components.push_back(binOp.getRHS());
1101 lvlExprComponents[pos] = components;
1103 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
1104 assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
1105 "expected floordiv before mod");
1110 assert(
false &&
"expected floordiv or mod");
1120 for (
auto &components : lvlExprComponents) {
1121 assert(components.second.size() == 3 &&
1122 "expected 3 components to build lvlExprs");
1127 lvlExprs.push_back(addOp);
1134 "expected dimToLvl to be block sparsity for calling getBlockSize");
1137 if (
auto binOp = dyn_cast<AffineBinaryOpExpr>(
result)) {
1139 blockSize.push_back(
1140 dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
1143 blockSize.push_back(0);
1152 std::map<unsigned, int64_t> coeffientMap;
1153 bool hasBlock =
false;
1155 if (
auto binOp = dyn_cast<AffineBinaryOpExpr>(
result)) {
1157 auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
1158 auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
1159 if (!dimOp || !conOp || conOp.getValue() <= 0)
1162 auto pos = dimOp.getPosition();
1165 auto [it,
inserted] = coeffientMap.try_emplace(pos);
1169 it->second = conOp.getValue();
1172 auto it = coeffientMap.find(pos);
1173 if (it == coeffientMap.end())
1176 if (conOp.getValue() != it->second)
1182 }
else if (
auto dimOp = dyn_cast<AffineDimExpr>(
result)) {
1183 auto pos = dimOp.getPosition();
1185 if (!coeffientMap.try_emplace(pos, 0).second)
1195 auto hasNonIdentityMap = [](
Value v) {
1200 return llvm::any_of(op->
getOperands(), hasNonIdentityMap) ||
1201 llvm::any_of(op->
getResults(), hasNonIdentityMap);
1206 assert(enc.isPermutation() &&
"Non permutation map not supported");
1207 if (
const auto dimToLvl = enc.getDimToLvl())
1215 assert(enc.isPermutation() &&
"Non permutation map not supported");
1216 if (
const auto lvlToDim = enc.getLvlToDim())
1226static SparseTensorEncodingAttr
1229 for (
auto lt : enc.getLvlTypes())
1232 return SparseTensorEncodingAttr::get(
1233 enc.getContext(), lts,
1243 enc.getDimSlices());
1247StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
1254 SparseTensorEncodingAttr encoding) {
1273 StorageSpecifierKind mdKind, std::optional<Level> lvl,
1275 if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
1277 "redundant level argument for querying value memory size");
1280 const auto enc = md.getType().getEncoding();
1281 const Level lvlRank = enc.getLvlRank();
1283 if (mdKind == StorageSpecifierKind::DimOffset ||
1284 mdKind == StorageSpecifierKind::DimStride)
1286 return op->
emitError(
"requested slice data on non-slice tensor");
1288 if (mdKind != StorageSpecifierKind::ValMemSize) {
1290 return op->
emitError(
"missing level argument");
1292 const Level l = lvl.value();
1294 return op->
emitError(
"requested level is out of bounds");
1296 if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
1298 "requested position memory size on a singleton level");
1314 llvm_unreachable(
"Unrecognizable FieldKind");
1319 RankedTensorType valTp,
1322 return op->
emitError(
"the sparse-tensor must have static shape");
1324 return op->
emitError(
"the sparse-tensor must have an encoding attribute");
1330 auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
1332 unsigned expCOORank = stt.
getLvlRank() - cooStartLvl;
1333 if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
1334 return op->
emitError(
"input/output trailing COO level-ranks don't match");
1341 return op->
emitError(
"inconsistent number of fields between input/output");
1344 bool misMatch =
false;
1351 Type inputTp =
nullptr;
1355 assert(fid == idx && stt.
getLvlType(lvl) == lt);
1356 inputTp = lvlTps[idx++];
1359 Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
1361 if (inpElemTp != expElemTp) {
1369 return op->
emitError(
"input/output element-types don't match");
1373LogicalResult AssembleOp::verify() {
1374 RankedTensorType valuesTp = getValues().getType();
1375 const auto lvlsTp = getLevels().getTypes();
1380LogicalResult DisassembleOp::verify() {
1382 return emitError(
"output values and return value type mismatch");
1384 for (
auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1385 if (ot.getType() != rt.getType())
1386 return emitError(
"output levels and return levels type mismatch");
1388 RankedTensorType valuesTp = getRetValues().getType();
1389 const auto lvlsTp = getRetLevels().getTypes();
1394LogicalResult ConvertOp::verify() {
1395 RankedTensorType tp1 = getSource().getType();
1396 RankedTensorType tp2 = getDest().getType();
1397 if (tp1.getRank() != tp2.getRank())
1398 return emitError(
"unexpected conversion mismatch in rank");
1400 llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
1401 if (dstEnc && dstEnc.isSlice())
1402 return emitError(
"cannot convert to a sparse tensor slice");
1404 auto shape1 = tp1.getShape();
1405 auto shape2 = tp2.getShape();
1409 for (
Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1410 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1411 return emitError(
"unexpected conversion mismatch in dimension ") << d;
1415OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
1421bool ConvertOp::needsExtraSort() {
1440 if (
auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1441 if (isa<SparseElementsAttr>(constOp.getValue()))
1447LogicalResult CrdTranslateOp::verify() {
1448 uint64_t inRank = getEncoder().getLvlRank();
1449 uint64_t outRank = getEncoder().getDimRank();
1451 if (getDirection() == CrdTransDirectionKind::dim2lvl)
1452 std::swap(inRank, outRank);
1454 if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1455 return emitError(
"Coordinate rank mismatch with encoding");
1460LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
1461 SmallVectorImpl<OpFoldResult> &results) {
1462 if (getEncoder().isIdentity()) {
1463 results.assign(getInCrds().begin(), getInCrds().end());
1467 AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
1468 ? getEncoder().getDimToLvl()
1469 : getEncoder().getLvlToDim();
1471 results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
1476 auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
1477 bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
1483 bool oppositeDir = def.getDirection() != getDirection();
1485 def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
1486 bool sameCount = def.getNumResults() == getInCrds().size();
1487 if (!oppositeDir || !sameOracle || !sameCount)
1492 bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
1493 [](
auto valuePair) {
1494 auto [
lhs,
rhs] = valuePair;
1502 results.append(def.getInCrds().begin(), def.getInCrds().end());
1506void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1509 return build(builder, state, source, val);
1512LogicalResult LvlOp::verify() {
1513 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1515 if (
static_cast<uint64_t
>(lvl.value()) >= stt.
getLvlRank())
1517 "Level index exceeds the rank of the input sparse tensor");
1522std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1532 cast<RankedTensorType>(getSource().
getType()).getRank());
1536OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1537 auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1541 Level lvl = lvlIndex.getAPSInt().getZExtValue();
1551 auto getIndexAttr = [
this](int64_t lvlSz) {
1552 return IntegerAttr::get(IndexType::get(
getContext()), APInt(64, lvlSz));
1556 if (ShapedType::isStatic(lvlShape[lvl]))
1557 return getIndexAttr(lvlShape[lvl]);
1562void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1563 SparseTensorEncodingAttr dstEnc, Value source) {
1565 SmallVector<int64_t> srcLvlShape = srcStt.
getLvlShape();
1566 SmallVector<int64_t> dstDimShape =
1567 dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1569 RankedTensorType::get(dstDimShape, srcStt.
getElementType(), dstEnc);
1570 return build(odsBuilder, odsState, dstTp, source);
1573LogicalResult ReinterpretMapOp::verify() {
1576 ArrayRef<LevelType> srcLvlTps = srcStt.
getLvlTypes();
1577 ArrayRef<LevelType> dstLvlTps = dstStt.
getLvlTypes();
1579 if (srcLvlTps.size() != dstLvlTps.size())
1580 return emitError(
"Level rank mismatch between source/dest tensors");
1582 for (
auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1583 if (srcLvlTp != dstLvlTp)
1584 return emitError(
"Level type mismatch between source/dest tensors");
1588 return emitError(
"Crd/Pos width mismatch between source/dest tensors");
1592 return emitError(
"Element type mismatch between source/dest tensors");
1594 SmallVector<Size> srcLvlShape = srcStt.
getLvlShape();
1595 SmallVector<Size> dstLvlShape = dstStt.
getLvlShape();
1596 for (
auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1597 if (srcLvlSz != dstLvlSz) {
1601 return emitError(
"Level size mismatch between source/dest tensors");
1608OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1612 if (
auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1614 if (def.getSource().getType() == getDest().
getType())
1615 return def.getSource();
1620template <
typename ToBufferOp>
1624 typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1626 Type elemTp =
nullptr;
1627 bool withStride =
false;
1628 if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1630 }
else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1631 std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1633 if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1635 }
else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1639 assert(elemTp &&
"unhandled operation.");
1641 bufShape.push_back(ShapedType::kDynamic);
1643 auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get(
1645 {ShapedType::kDynamic})
1646 : StridedLayoutAttr();
1647 ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1651LogicalResult ToPositionsOp::verify() {
1654 return emitError(
"requested level is out of bounds");
1656 return emitError(
"unexpected type for positions");
1661ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1663 PropertyRef prop, RegionRange region,
1664 SmallVectorImpl<mlir::Type> &ret) {
1668LogicalResult ToCoordinatesOp::verify() {
1671 return emitError(
"requested level is out of bounds");
1673 return emitError(
"unexpected type for coordinates");
1678ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1680 PropertyRef prop, RegionRange region,
1681 SmallVectorImpl<mlir::Type> &ret) {
1685LogicalResult ToCoordinatesBufferOp::verify() {
1688 return emitError(
"expected sparse tensor with a COO region");
1692LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1693 MLIRContext *ctx, std::optional<Location> loc,
ValueRange ops,
1694 DictionaryAttr attr, PropertyRef prop, RegionRange region,
1695 SmallVectorImpl<mlir::Type> &ret) {
1700LogicalResult ToValuesOp::verify() {
1704 return emitError(
"unexpected mismatch in element types");
1708LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1709 std::optional<Location> loc,
1711 PropertyRef prop, RegionRange region,
1712 SmallVectorImpl<mlir::Type> &ret) {
1716LogicalResult ToSliceOffsetOp::verify() {
1717 auto rank =
getSlice().getType().getRank();
1718 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1719 return emitError(
"requested dimension out of bound");
1723LogicalResult ToSliceStrideOp::verify() {
1724 auto rank =
getSlice().getType().getRank();
1725 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1726 return emitError(
"requested dimension out of bound");
1730LogicalResult GetStorageSpecifierOp::verify() {
1732 getSpecifier(), getOperation());
1735template <
typename SpecifierOp>
1737 return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1740OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1741 const StorageSpecifierKind kind = getSpecifierKind();
1742 const auto lvl = getLevel();
1744 if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1745 return op.getValue();
1749LogicalResult SetStorageSpecifierOp::verify() {
1751 getSpecifier(), getOperation());
1756 const char *regionName,
1759 unsigned expectedNum = inputTypes.size();
1760 if (numArgs != expectedNum)
1761 return op->emitError() << regionName <<
" region must have exactly "
1762 << expectedNum <<
" arguments";
1764 for (
unsigned i = 0; i < numArgs; i++) {
1766 if (typ != inputTypes[i])
1767 return op->emitError() << regionName <<
" region argument " << (i + 1)
1768 <<
" type mismatch";
1772 return op->emitError() << regionName
1773 <<
" region must end with a terminator";
1776 YieldOp yield = dyn_cast<YieldOp>(term);
1778 return op->emitError() << regionName
1779 <<
" region must end with sparse_tensor.yield";
1780 if (!yield.hasSingleResult() ||
1781 yield.getSingleResult().getType() != outputType)
1782 return op->emitError() << regionName <<
" region yield type mismatch";
1787LogicalResult BinaryOp::verify() {
1788 NamedAttrList attrs = (*this)->getAttrs();
1789 Type leftType = getX().getType();
1790 Type rightType = getY().getType();
1791 Type outputType = getOutput().getType();
1792 Region &overlap = getOverlapRegion();
1793 Region &left = getLeftRegion();
1794 Region &right = getRightRegion();
1798 if (!overlap.
empty()) {
1800 TypeRange{leftType, rightType}, outputType)))
1803 if (!left.
empty()) {
1807 }
else if (getLeftIdentity()) {
1808 if (leftType != outputType)
1809 return emitError(
"left=identity requires first argument to have the same "
1810 "type as the output");
1812 if (!right.
empty()) {
1816 }
else if (getRightIdentity()) {
1817 if (rightType != outputType)
1818 return emitError(
"right=identity requires second argument to have the "
1819 "same type as the output");
1824LogicalResult UnaryOp::verify() {
1825 Type inputType = getX().getType();
1826 Type outputType = getOutput().getType();
1830 Region &present = getPresentRegion();
1831 if (!present.
empty()) {
1836 Region &absent = getAbsentRegion();
1837 if (!absent.
empty()) {
1843 Block *parent = getOperation()->getBlock();
1845 cast<YieldOp>(absentBlock->
getTerminator()).getSingleResult();
1846 if (
auto arg = dyn_cast<BlockArgument>(absentVal)) {
1847 if (arg.getOwner() == parent)
1848 return emitError(
"absent region cannot yield linalg argument");
1850 if (!isa<arith::ConstantOp>(def) &&
1851 (def->getBlock() == absentBlock || def->getBlock() == parent))
1852 return emitError(
"absent region cannot yield locally computed value");
1858bool ConcatenateOp::needsExtraSort() {
1863 bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1870 bool directLowerable =
1871 allSameOrdered && getDimension() == 0 && dstStt.
isIdentity();
1872 return !directLowerable;
1875LogicalResult ConcatenateOp::verify() {
1877 const Dimension concatDim = getDimension();
1878 const Dimension dimRank = dstTp.getDimRank();
1880 if (getInputs().size() <= 1)
1881 return emitError(
"Need at least two tensors to concatenate.");
1883 if (concatDim >= dimRank)
1885 "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1886 concatDim, dimRank));
1888 for (
const auto &it : llvm::enumerate(getInputs())) {
1889 const auto i = it.index();
1891 if (srcTp.hasDynamicDimShape())
1892 return emitError(llvm::formatv(
"Input tensor ${0} has dynamic shape", i));
1893 const Dimension srcDimRank = srcTp.getDimRank();
1894 if (srcDimRank != dimRank)
1896 llvm::formatv(
"Input tensor ${0} has a different rank (rank={1}) "
1897 "from the output tensor (rank={2}).",
1898 i, srcDimRank, dimRank));
1901 for (
Dimension d = 0; d < dimRank; d++) {
1902 const Size dstSh = dstTp.getDimShape()[d];
1903 if (d == concatDim) {
1904 if (ShapedType::isStatic(dstSh)) {
1909 for (
const auto src : getInputs())
1915 "The concatenation dimension of the output tensor should be the "
1916 "sum of all the concatenation dimensions of the input tensors.");
1920 for (
const auto src : getInputs()) {
1922 if (ShapedType::isStatic(prev) && sh != prev)
1923 return emitError(
"All dimensions (expect for the concatenating one) "
1924 "should be equal.");
1933void PushBackOp::build(OpBuilder &builder, OperationState &
result,
1934 Value curSize, Value inBuffer, Value value) {
1935 build(builder,
result, curSize, inBuffer, value, Value());
1938LogicalResult PushBackOp::verify() {
1939 if (Value n =
getN()) {
1941 if (nValue && nValue.value() < 1)
1947LogicalResult CompressOp::verify() {
1949 if (stt.
getLvlRank() != 1 +
static_cast<Level>(getLvlCoords().size()))
1950 return emitOpError(
"incorrect number of coordinates");
1954void ForeachOp::build(
1955 OpBuilder &builder, OperationState &
result, Value tensor,
1959 build(builder,
result, initArgs.
getTypes(), tensor, initArgs, order);
1967 SmallVector<Type> blockArgTypes(dimRank, builder.
getIndexType());
1971 blockArgTypes.append(initArgs.
getTypes().begin(), initArgs.
getTypes().end());
1973 SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.
getLoc());
1975 OpBuilder::InsertionGuard guard(builder);
1976 auto ®ion = *
result.regions.front();
1978 builder.
createBlock(®ion, region.end(), blockArgTypes, blockArgLocs);
1979 bodyBuilder(builder,
result.location,
1985LogicalResult ForeachOp::verify() {
1987 const Dimension dimRank = t.getDimRank();
1988 const auto args = getBody()->getArguments();
1990 if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1991 return emitError(
"Level traverse order does not match tensor's level rank");
1993 if (dimRank + 1 + getInitArgs().size() != args.size())
1994 return emitError(
"Unmatched number of arguments in the block");
1996 if (getNumResults() != getInitArgs().size())
1997 return emitError(
"Mismatch in number of init arguments and results");
1999 if (getResultTypes() != getInitArgs().getTypes())
2000 return emitError(
"Mismatch in types of init arguments and results");
2003 auto yield = cast<YieldOp>(getBody()->getTerminator());
2004 if (yield.getNumOperands() != getNumResults() ||
2005 yield.getOperands().getTypes() != getResultTypes())
2006 return emitError(
"Mismatch in types of yield values and results");
2008 const auto iTp = IndexType::get(
getContext());
2012 llvm::formatv(
"Expecting Index type for argument at index {0}", d));
2014 const auto elemTp = t.getElementType();
2015 const auto valueTp = args[dimRank].getType();
2016 if (elemTp != valueTp)
2018 llvm::formatv(
"Unmatched element type between input tensor and "
2019 "block argument, expected:{0}, got: {1}",
2024OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
2027 return getInputCoo();
2032LogicalResult ReorderCOOOp::verify() {
2037 return emitError(
"Expected COO sparse tensors only");
2040 return emitError(
"Unmatched dim2lvl map between input and result COO");
2045 return emitError(
"Unmatched storage format between input and result COO");
2050LogicalResult ReduceOp::verify() {
2051 Type inputType = getX().getType();
2052 Region &formula = getRegion();
2054 TypeRange{inputType, inputType}, inputType);
2057LogicalResult SelectOp::verify() {
2059 Type inputType = getX().getType();
2060 Type boolType =
b.getI1Type();
2061 Region &formula = getRegion();
2066LogicalResult SortOp::verify() {
2067 AffineMap xPerm = getPermMap();
2070 return emitError(llvm::formatv(
"Expected rank(perm_map) > 1, got {0}", nx));
2074 llvm::formatv(
"Expected a permutation map, got {0}", xPerm));
2083 const auto checkDim = [&](Value v,
Size minSize,
2084 const char *message) -> LogicalResult {
2086 if (ShapedType::isStatic(sh) && sh < minSize)
2088 llvm::formatv(
"{0} got {1} < {2}", message, sh, minSize));
2091 uint64_t n = cn.value();
2093 if (
auto nyAttr = getNyAttr())
2094 ny = nyAttr.getInt();
2095 if (
failed(checkDim(getXy(), n * (nx + ny),
2096 "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
2098 for (Value opnd : getYs())
2099 if (
failed(checkDim(opnd, n,
"Expected dimension(y) >= n")))
2109IterSpaceType IteratorType::getIterSpaceType()
const {
2110 return IterSpaceType::get(
getContext(), getEncoding(), getLoLvl(),
2114IteratorType IterSpaceType::getIteratorType()
const {
2115 return IteratorType::get(
getContext(), getEncoding(), getLoLvl(), getHiLvl());
2134 "expect larger level upper bound than lower bound");
2142 IntegerAttr &lvlHiAttr) {
2159 p << lo <<
" to " << hi;
2165 IntegerAttr lvlHi) {
2166 unsigned lo = lvlLo.getValue().getZExtValue();
2167 unsigned hi = lvlHi.getValue().getZExtValue();
2178 unsigned maxCnt = std::numeric_limits<unsigned>::max(),
2181 ParseResult crdList =
2186 definedSet.
set(cnt);
2194 "parsed more value than expected.");
2196 if (failed(crdList)) {
2199 "expecting SSA value or \"_\" for level coordinates");
2201 assert(definedArgs.size() == definedSet.
count());
2208 if (definedSet.
empty())
2211 for (
unsigned i = 0; i < size; i++) {
2212 if (definedSet[i]) {
2213 p << blocksArgs.front();
2214 blocksArgs = blocksArgs.drop_front();
2221 assert(blocksArgs.empty());
2234 for (
auto &coord : coords)
2255 if (iterators.size() != spaces.size())
2258 "mismatch in number of sparse iterators and sparse spaces");
2263 size_t numCrds = coords.size();
2271 blockArgs.append(coords);
2277 if (iterSpaceTps.size() != spaces.size())
2279 "mismatch in number of iteration space operands "
2280 "and iteration space types");
2282 for (
auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
2283 IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
2286 "expected sparse_tensor.iter_space type for "
2287 "iteration space operands");
2288 it.type = spaceTp.getIteratorType();
2303 if (args.size() != initArgs.size() || args.size() != state.
types.size()) {
2306 "mismatch in number of iteration arguments and return values");
2309 for (
auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.
types)) {
2331 size_t numCrds = coords.size();
2339 blockArgs.append(coords);
2347 if (iterSpaceTps.size() != spaces.size())
2349 "mismatch in number of iteration space operands "
2350 "and iteration space types");
2365 if (args.size() != initArgs.size() || args.size() != state.
types.size()) {
2368 "mismatch in number of iteration arguments and return values");
2371 for (
auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.
types)) {
2380LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2381 MLIRContext *ctx, std::optional<Location> loc,
ValueRange ops,
2382 DictionaryAttr attr, PropertyRef prop, RegionRange region,
2383 SmallVectorImpl<mlir::Type> &ret) {
2385 ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2387 ret.push_back(IterSpaceType::get(ctx, stt.
getEncoding(), adaptor.getLoLvl(),
2388 adaptor.getHiLvl()));
2392LogicalResult ExtractIterSpaceOp::verify() {
2393 if (getLoLvl() >= getHiLvl())
2394 return emitOpError(
"expected smaller level low than level high");
2397 if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2399 "parent iterator should be specified iff level lower bound equals 0");
2403 IterSpaceType spaceTp = getExtractedSpace().getType();
2404 if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2406 "mismatch in parent iterator encoding and iteration space encoding.");
2408 if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2409 return emitOpError(
"parent iterator should be used to extract an "
2410 "iteration space from a consecutive level.");
2416LogicalResult ExtractValOp::verify() {
2418 auto itTp = getIterator().getType();
2421 return emitOpError(
"mismatch in tensor encoding and iterator encoding.");
2424 return emitOpError(
"must use last-level iterator to extract values. ");
2435 llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
2436 for (
unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
2437 if (
auto crd = iterateOp.getLvlCrd(i)) {
2438 if (crd->getUsers().empty())
2439 toRemove.set(crd->getArgNumber());
2446 if (toRemove.none())
2450 iterateOp.setCrdUsedLvls(newUsedLvls);
2451 iterateOp.getBody()->eraseArguments(toRemove);
2457void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
2458 mlir::MLIRContext *context) {
2459 results.
add<RemoveUnusedLvlCrds>(context);
2462void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2464 unsigned rank = llvm::cast<IterSpaceType>(iterSpace.
getType()).getSpaceDim();
2467 return build(builder, odsState, iterSpace, initArgs, set);
2470void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2473 OpBuilder::InsertionGuard guard(builder);
2479 Region *bodyRegion = odsState.
addRegion();
2484 for (Value v : initArgs)
2488 for (
unsigned i = 0, e = crdUsedLvls.
count(); i < e; i++)
2493 llvm::cast<IterSpaceType>(iterSpace.
getType()).getIteratorType(),
2497ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &
result) {
2498 OpAsmParser::Argument iterator;
2499 OpAsmParser::UnresolvedOperand iterSpace;
2501 SmallVector<OpAsmParser::Argument> iters, iterArgs;
2504 if (iters.size() != 1)
2506 "expected only one iterator/iteration space");
2508 iterArgs.append(iters);
2509 Region *body =
result.addRegion();
2529 StringRef prefix =
"") {
2530 assert(blocksArgs.size() == initializers.size() &&
2531 "expected same length of arguments and initializers");
2532 if (initializers.empty())
2536 llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](
auto it) {
2537 p << std::get<0>(it) <<
" = " << std::get<1>(it);
2542template <
typename SparseLoopOp>
2544 if (op.getInitArgs().size() != op.getNumResults()) {
2545 return op.emitOpError(
2546 "mismatch in number of loop-carried values and defined values");
2548 if (op.getCrdUsedLvls().max() > op.getSpaceDim())
2549 return op.emitOpError(
"required out-of-bound coordinates");
2557void IterateOp::print(OpAsmPrinter &p) {
2558 p <<
" " << getIterator() <<
" in " << getIterSpace();
2559 if (!getCrdUsedLvls().empty()) {
2566 p <<
" : " << getIterSpace().getType() <<
" ";
2567 if (!getInitArgs().empty())
2572 !getInitArgs().empty());
2575LogicalResult IterateOp::verifyRegions() {
2576 if (getIterator().
getType() != getIterSpace().
getType().getIteratorType())
2577 return emitOpError(
"mismatch in iterator and iteration space type");
2578 if (getNumRegionIterArgs() != getNumResults())
2580 "mismatch in number of basic block args and defined values");
2582 auto initArgs = getInitArgs();
2583 auto iterArgs = getRegionIterArgs();
2584 auto yieldVals = getYieldedValues();
2585 auto opResults = getResults();
2586 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2587 opResults.size()})) {
2588 return emitOpError() <<
"number mismatch between iter args and results.";
2591 for (
auto [i, init, iter, yield, ret] :
2592 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2593 if (init.getType() != ret.getType())
2594 return emitOpError() <<
"types mismatch between " << i
2595 <<
"th iter operand and defined value";
2596 if (iter.getType() != ret.getType())
2597 return emitOpError() <<
"types mismatch between " << i
2598 <<
"th iter region arg and defined value";
2599 if (yield.getType() != ret.getType())
2600 return emitOpError() <<
"types mismatch between " << i
2601 <<
"th yield value and defined value";
2608SmallVector<Region *> IterateOp::getLoopRegions() {
return {&getRegion()}; }
2610MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
2611 return getInitArgsMutable();
2615 return getRegion().getArguments().take_front(getNumRegionIterArgs());
2618std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
2619 return cast<sparse_tensor::YieldOp>(
2620 getRegion().getBlocks().front().getTerminator())
2621 .getResultsMutable();
2624std::optional<ResultRange> IterateOp::getLoopResults() {
return getResults(); }
2626OperandRange IterateOp::getEntrySuccessorOperands(RegionSuccessor successor) {
2627 return getInitArgs();
2630void IterateOp::getSuccessorRegions(RegionBranchPoint point,
2631 SmallVectorImpl<RegionSuccessor> ®ions) {
2634 regions.push_back(RegionSuccessor(&getRegion()));
2639ValueRange IterateOp::getSuccessorInputs(RegionSuccessor successor) {
2644void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
2646 unsigned numCases) {
2648 cast<IterSpaceType>(iterSpaces.front().
getType()).getSpaceDim();
2655 SmallVector<int64_t> caseBits(numCases, 0);
2657 return CoIterateOp::build(builder, odsState, initArgs.
getTypes(), iterSpaces,
2658 initArgs, set, cases,
2662ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &
result) {
2664 SmallVector<Value> spaces;
2667 SmallVector<OpAsmParser::Argument> blockArgs;
2671 result.addAttribute(
"operandSegmentSizes",
2673 {static_cast<int32_t>(spaces.size()),
2674 static_cast<int32_t>(result.types.size())}));
2676 SmallVector<Attribute> cases;
2680 SmallVector<OpAsmParser::Argument> definedIts;
2687 for (
auto [i, definedIdx] : llvm::enumerate(definedItSet.
bits())) {
2689 auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].
getType());
2690 definedIts[i].type = spaceTp.getIteratorType();
2692 definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
2693 Region *body =
result.addRegion();
2697 CoIterateOp::ensureTerminator(*body, parser.
getBuilder(),
result.location);
2709void CoIterateOp::print(OpAsmPrinter &p) {
2711 llvm::interleaveComma(getIterSpaces(), p, [&](
auto s) { p << s; });
2714 if (!getCrdUsedLvls().empty()) {
2722 p <<
" : (" << getIterSpaces().getTypes() <<
")";
2723 if (!getInitArgs().empty())
2724 p.printArrowTypeList(getInitArgs().getTypes());
2726 for (
unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
2730 getRegionDefinedSpace(idx));
2732 p.printRegion(getRegion(idx),
false,
2733 !getInitArgs().empty());
2737ValueRange CoIterateOp::getYieldedValues(
unsigned regionIdx) {
2738 return cast<sparse_tensor::YieldOp>(
2739 getRegion(regionIdx).getBlocks().front().getTerminator())
2743LogicalResult CoIterateOp::verifyRegions() {
2744 for (
unsigned r = 0, e = getNumRegions(); r < e; r++) {
2745 if (getNumRegionIterArgs() != getNumResults())
2747 "mismatch in number of basic block args and defined values");
2749 auto initArgs = getInitArgs();
2750 auto iterArgs = getRegionIterArgs(r);
2751 auto yieldVals = getYieldedValues(r);
2752 auto opResults = getResults();
2753 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2754 opResults.size()})) {
2756 <<
"number mismatch between iter args and results on " << r
2760 for (
auto [i, init, iter, yield, ret] :
2761 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2762 if (init.getType() != ret.getType())
2764 <<
"types mismatch between " << i
2765 <<
"th iter operand and defined value on " << r <<
"th region";
2766 if (iter.getType() != ret.getType())
2767 return emitOpError() <<
"types mismatch between " << i
2768 <<
"th iter region arg and defined value on " << r
2770 if (yield.getType() != ret.getType())
2772 <<
"types mismatch between " << i
2773 <<
"th yield value and defined value on " << r <<
"th region";
2777 auto cases = getRegionDefinedSpaces();
2778 llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
2779 if (set.size() != getNumRegions())
2785SmallVector<Region *> CoIterateOp::getSubCasesOf(
unsigned regionIdx) {
2786 SmallVector<Region *> ret;
2787 I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
2788 for (Region &r : getCaseRegions())
2789 if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
2801Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2802 Attribute value, Type type,
2804 if (
auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2809void SparseTensorDialect::initialize() {
2811#define GET_ATTRDEF_LIST
2812#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
2815#define GET_TYPEDEF_LIST
2816#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
2820#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2822 declarePromisedInterfaces<
2823 bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
2824 NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
2825 ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
2828#define GET_OP_CLASSES
2829#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2831#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
static bool isPermutation(const std::vector< PermutationTy > &permutation)
static Type getElementType(Type type)
Determine the element type of type.
If copies could not be generated due to yet-unimplemented cases, copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point). Note that `begin` can itself be invalidated due to the memref rewriting done from this method.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static bool isUnique(It begin, It end)
static LogicalResult verifyNumBlockArgs(T *op, Region ®ion, const char *regionName, TypeRange inputTypes, Type outputType)
static ParseResult parseOptionalStaticSlice(int64_t &result, AsmParser &parser)
static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc)
We normalize the sparse tensor encoding attribute by always using ordered/unique LT such that "compresse...
static ParseResult parseUsedCoordList(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &coords)
static LogicalResult isMatchingWidth(Value mem, unsigned width)
static constexpr bool acceptBitWidth(unsigned bitWidth)
static mlir::ParseResult parseLevelRange(mlir::AsmParser &, mlir::sparse_tensor::Level &, mlir::sparse_tensor::Level &)
Parses a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static LogicalResult lvlIsInBounds(Level lvl, Value tensor)
static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size, Block::BlockArgListType blocksArgs, I64BitSet definedSet)
static constexpr FieldIndex kDataFieldStartingIdx
static constexpr Level kInvalidLevel
static LogicalResult verifySparseLoopOp(SparseLoopOp op)
static constexpr Level kInvalidFieldIndex
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, mlir::sparse_tensor::Level)
Prints a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind)
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op)
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, PropertyRef prop, RegionRange region, SmallVectorImpl< mlir::Type > &ret)
static ParseResult parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &iterators, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static SmallVector< Size > getSparseFieldShape(const SparseTensorEncodingAttr enc, std::optional< ArrayRef< int64_t > > dimShape)
static ParseResult parseOptionalDefinedList(OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, SmallVectorImpl< OpAsmParser::Argument > &definedArgs, unsigned maxCnt=std::numeric_limits< unsigned >::max(), OpAsmParser::Delimiter delimiter=OpAsmParser::Delimiter::Paren)
Parses a list of optional defined list in the form of "(%val0, _, %val1, ...)", where _ is used to an...
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, SparseTensorType stt, RankedTensorType valTp, TypeRange lvlTps)
static ParseResult parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state, SmallVectorImpl< Value > &spacesVals, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static LogicalResult verifySparsifierGetterSetter(StorageSpecifierKind mdKind, std::optional< Level > lvl, TypedValue< StorageSpecifierType > md, Operation *op)
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
void print(raw_ostream &os) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseLBrace()=0
Parse a { token.
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseQuestion()=0
Parse a '?' token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printArrowTypeList(TypeRange &&types)
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
BlockArgListType getArguments()
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
This class provides an abstraction over the different types of ranges over Regions.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
unsigned getNumArguments()
BlockArgument getArgument(unsigned i)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
A simple wrapper to encode a bitset of (at most 64) levels, currently used by sparse_tensor....
iterator_range< const_set_bits_iterator > bits() const
I64BitSet & set(unsigned i)
A wrapper around RankedTensorType, which has three goals:
bool isSingletonLvl(Level l) const
SmallVector< Size > getBatchLvlShape() const
Returns the batched level-shape.
MLIRContext * getContext() const
Type getElementType() const
bool isLooseCompressedLvl(Level l) const
unsigned getCrdWidth() const
Returns the coordinate-overhead bitwidth, defaulting to zero.
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
bool isAllOrdered() const
Returns true for tensors where every level is ordered.
bool isCOOType(Level startLvl=0, bool isUnique=true) const
Returns true iff this sparse tensor type has a trailing COO region starting at the given level.
Dimension getDimRank() const
Returns the dimension-rank.
AffineMap getLvlToDim() const
Returns the lvlToDim mapping (or the null-map for the identity).
Attribute getImplicitVal() const
Returns the implicit value, defaulting to null Attribute for 0.
bool isAllDense() const
Returns true for tensors where every level is dense.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
bool hasSameDimToLvl(const SparseTensorType &other) const
Returns true iff the two types have the same mapping.
ArrayRef< Size > getDimShape() const
Returns the dimension-shape.
SmallVector< Size > getLvlShape() const
Returns the level-shape.
bool isCompressedLvl(Level l) const
bool hasStaticDimShape() const
Returns true if no dimension has dynamic size.
Level getLvlRank() const
Returns the level-rank.
ArrayRef< LevelType > getLvlTypes() const
unsigned getPosWidth() const
Returns the position-overhead bitwidth, defaulting to zero.
RankedTensorType getCOOType(bool ordered) const
Returns [un]ordered COO type for this sparse tensor type.
SparseTensorEncodingAttr getEncoding() const
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
LevelType getLvlType(Level l) const
AffineMap getDimToLvl() const
Returns the dimToLvl mapping (or the null-map for the identity).
Attribute getExplicitVal() const
Returns the explicit value, defaulting to null Attribute for unset.
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
bool isUniqueLvl(Level l) const
Provides methods to access fields of a sparse tensor with the given encoding.
unsigned getNumDataFields() const
Gets the total number of data fields (coordinate arrays, position arrays, and a value array) for the ...
unsigned getNumFields() const
Gets the total number of fields for the given sparse tensor encoding.
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with t...
std::pair< FieldIndex, unsigned > getFieldIndexAndStride(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Parses the Sparse Tensor Encoding Attribute (STEA).
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
bool isUniqueLT(LevelType lt)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool isWithCrdLT(LevelType lt)
std::optional< LevelType > buildLevelType(LevelFormat lf, const std::vector< LevelPropNonDefault > &properties, uint64_t n=0, uint64_t m=0)
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
bool isWithPosLT(LevelType lt)
bool isOrderedLT(LevelType lt)
std::string toMLIRString(LevelType lt)
Dimension toDim(SparseTensorEncodingAttr enc, Level l)
Convenience method to translate the given level to the corresponding dimension.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
bool isSingletonLT(LevelType lt)
static llvm::hash_code hash_value(LevelType lt)
uint64_t getN(LevelType lt)
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context)
Given the dimToLvl map, infers the lvlToDim map, or returns empty Affine map when inference fails.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Level toLvl(SparseTensorEncodingAttr enc, Dimension d)
Convenience method to translate the given dimension to the corresponding level.
bool isBlockSparsity(AffineMap dimToLvl)
Given the dimToLvl map, returns if it's block sparsity.
bool isDenseLT(LevelType lt)
uint64_t getM(LevelType lt)
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SparseTensorFieldKind
===----------------------------------------------------------------------===// The sparse tensor storage...
bool isBatchLT(LevelType lt)
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context)
Returns the lvlToDim map for the given dimToLvl map specific to the block sparse cases.
bool isNOutOfMLT(LevelType lt)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
llvm::function_ref< Fn > function_ref
LogicalResult matchAndRewrite(IterateOp iterateOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) the properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
A simple structure that encodes a range of levels in the sparse tensors that forms a COO segment.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.
LevelType stripStorageIrrelevantProperties() const