32#include "llvm/ADT/SmallVectorExtras.h"
46 for (
const auto &vals : values)
47 llvm::append_range(
result, vals);
54 return memref::LoadOp::create(builder, loc, mem, idx);
61 val =
genCast(builder, loc, val,
62 cast<ShapedType>(mem.
getType()).getElementType());
63 memref::StoreOp::create(builder, loc, val, mem, idx);
75 scf::ForOp::create(builder, loc, lower, upper, one, fields);
76 for (
unsigned i = 0, e = fields.size(); i < e; i++)
77 fields[i] = forOp.getRegionIterArg(i);
91 auto pushBackOp = PushBackOp::create(
93 field,
genCast(builder, loc, value, etp), repeat);
97 pushBackOp.getNewSize());
106 for (
Level lvl = startLvl; lvl < lvlRank; lvl++) {
117 linear = arith::MulIOp::create(builder, loc, linear, two);
130 linear = arith::MulIOp::create(builder, loc, linear, size);
135 std::nullopt, valZero, linear);
140 MemRefType memRefType,
Value sz,
142 Value buffer = memref::AllocOp::create(builder, loc, memRefType, sz);
143 Type elemType = memRefType.getElementType();
146 linalg::FillOp::create(builder, loc, fillValue, buffer);
156 dimSizesValues.clear();
157 dimSizesValues.reserve(dimRank);
160 dimSizesValues.push_back(ShapedType::isDynamic(sz)
179 Value posHeuristic, crdHeuristic, valHeuristic;
181 valHeuristic = lvlSizesValues[0];
182 for (
Level lvl = 1; lvl < lvlRank; lvl++)
183 valHeuristic = arith::MulIOp::create(builder, loc, valHeuristic,
184 lvlSizesValues[lvl]);
185 }
else if (sizeHint) {
188 crdHeuristic = arith::MulIOp::create(
189 builder, loc,
constantIndex(builder, loc, lvlRank), sizeHint);
191 posHeuristic = arith::AddIOp::create(builder, loc, sizeHint,
193 crdHeuristic = sizeHint;
195 posHeuristic = crdHeuristic =
constantIndex(builder, loc, 16);
197 valHeuristic = sizeHint;
199 posHeuristic = crdHeuristic = valHeuristic =
206 [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
209 assert(fields.size() == fIdx);
217 posHeuristic, enableInit);
221 crdHeuristic, enableInit);
225 valHeuristic, enableInit);
229 fields.push_back(field);
238 for (
Level lvl = 0, lvlRank = stt.
getLvlRank(); lvl < lvlRank; lvl++) {
239 desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
275 assert(lvl < lvlRank &&
"Level is out of bounds");
276 assert(lvlCoords.size() ==
static_cast<size_t>(lvlRank) &&
277 "Level-rank mismatch");
285 const Value pp1 = arith::AddIOp::create(builder, loc, parentPos, one);
287 const Value pstart =
genLoad(builder, loc, positionsAtLvl, parentPos);
288 const Value pstop =
genLoad(builder, loc, positionsAtLvl, pp1);
290 const Value crdStrideC =
293 crdStrideC ? arith::DivUIOp::create(builder, loc, crdMsz, crdStrideC)
295 const Value plast = arith::SubIOp::create(
296 builder, loc,
genCast(builder, loc, pstop, indexType), one);
298 Value lt = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult,
300 types.push_back(boolType);
301 scf::IfOp ifOp1 = scf::IfOp::create(builder, loc, types, lt,
true);
306 crdStrideC ? arith::MulIOp::create(builder, loc, plast, crdStrideC)
308 Value eq = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
309 genCast(builder, loc, crd, indexType),
311 scf::YieldOp::create(builder, loc, eq);
314 genStore(builder, loc, msz, positionsAtLvl, parentPos);
315 scf::YieldOp::create(builder, loc,
constantI1(builder, loc,
false));
322 for (
unsigned i = 0, e = desc.
getNumFields(); i < e; i++)
324 types.push_back(indexType);
327 scf::IfOp ifOp2 = scf::IfOp::create(builder, loc, types, p,
true);
334 scf::YieldOp::create(builder, loc, desc.
getFields());
339 Value mszp1 = arith::AddIOp::create(builder, loc, msz, one);
340 genStore(builder, loc, mszp1, positionsAtLvl, pp1);
344 if ((lvl + 1) < lvlRank)
348 scf::YieldOp::create(builder, loc, desc.
getFields());
354 for (
unsigned i = 0, e = desc.
getNumFields(); i < e; i++)
355 desc.
setField(i, ifOp2.getResult(o++));
356 return ifOp2.getResult(o);
364 for (
Level lvl = 0; lvl < lvlRank; lvl++) {
381 scf::ForOp loop =
createFor(builder, loc, hi, inits, one);
382 Value i = loop.getInductionVar();
383 Value oldv = loop.getRegionIterArg(0);
386 Value cond = arith::CmpIOp::create(
387 builder, loc, arith::CmpIPredicate::eq, newv, posZero);
388 scf::IfOp ifOp = scf::IfOp::create(builder, loc,
TypeRange(posType),
391 genStore(builder, loc, oldv, posMemRef, i);
392 scf::YieldOp::create(builder, loc, oldv);
394 scf::YieldOp::create(builder, loc, newv);
396 scf::YieldOp::create(builder, loc, ifOp.getResult(0));
409 auto memTp = llvm::cast<MemRefType>(mem.
getType());
413 if (memTp.getRank() > 1)
416 return memref::SubViewOp::create(
418 MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
432 for (
unsigned i = 0; i < batchLvls; i++)
435 for (
int i = batchLvls, e = srcTp.getRank(); i < e; i++)
436 ret.back().push_back(i);
448class SparseInsertGenerator
453 : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp) {};
466 OpBuilder &builder, Location loc) {
467 const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
468 const Level lvlRank = stt.getLvlRank();
470 SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
471 MutSparseTensorDescriptor desc(stt, fields);
472 const SmallVector<Value> coords =
473 llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
474 Value value = args.back();
477 for (
Level lvl = 0; lvl < lvlRank; lvl++) {
478 const auto lt = stt.getLvlType(lvl);
489 parentPos = arith::MulIOp::create(builder, loc, parentPos, two);
492 genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
498 createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
505 Value size = desc.
getLvlSize(builder, loc, lvl);
506 Value mult = arith::MulIOp::create(builder, loc, size, parentPos);
507 parentPos = arith::AddIOp::create(builder, loc, mult, coords[lvl]);
511 if (!stt.isDenseLvl(lvlRank - 1))
512 createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
513 std::nullopt, value);
519 std::string getMangledFuncName() {
522 constexpr const char kInsertFuncNamePrefix[] =
"_insert_";
523 const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
524 SmallString<32> nameBuffer;
525 llvm::raw_svector_ostream nameOstream(nameBuffer);
526 nameOstream << kInsertFuncNamePrefix;
527 const Level lvlRank = stt.getLvlRank();
528 for (
Level l = 0; l < lvlRank; l++) {
532 lvlType.begin(), lvlType.end(),
533 [](
char c) { return c ==
'(' || c ==
','; },
'_');
534 llvm::erase_if(lvlType, [](
char c) {
return c ==
')' || c ==
' '; });
535 nameOstream << lvlType <<
"_";
540 for (
const auto sz : stt.getDimShape())
541 nameOstream << sz <<
"_";
543 if (!stt.isIdentity())
544 nameOstream << stt.getDimToLvl() <<
"_";
545 nameOstream << stt.getElementType() <<
"_";
546 nameOstream << stt.getCrdWidth() <<
"_" << stt.getPosWidth();
547 return nameOstream.str().str();
555class SparseReturnConverter :
public OpConversionPattern<func::ReturnOp> {
557 using OpConversionPattern::OpConversionPattern;
559 matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
560 ConversionPatternRewriter &rewriter)
const override {
562 rewriter.replaceOpWithNewOp<func::ReturnOp>(
569class SparseCallConverter :
public OpConversionPattern<func::CallOp> {
572 using OpConversionPattern::OpConversionPattern;
574 matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
575 ConversionPatternRewriter &rewriter)
const override {
576 Location loc = op.getLoc();
582 SmallVector<Type> finalRetTy;
583 if (
failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
588 func::CallOp::create(rewriter, loc, op.getCallee(), finalRetTy,
591 SmallVector<SmallVector<Value>> packedResultVals;
594 unsigned retOffset = 0;
597 SmallVector<Type> sparseFlat;
598 for (
auto ret : op.getResults()) {
599 assert(retOffset < newCall.getNumResults());
600 auto retType = ret.getType();
601 if (
failed(typeConverter->convertType(retType, sparseFlat)))
602 llvm_unreachable(
"Failed to convert type in sparse tensor codegen");
605 assert(!sparseFlat.empty());
606 if (sparseFlat.size() > 1) {
607 auto flatSize = sparseFlat.size();
608 packedResultVals.emplace_back();
609 llvm::append_range(packedResultVals.back(),
610 newCall.getResults().slice(retOffset, flatSize));
611 retOffset += flatSize;
614 packedResultVals.emplace_back();
615 packedResultVals.back().push_back(newCall.getResult(retOffset));
621 assert(packedResultVals.size() == op.getNumResults());
622 rewriter.replaceOpWithMultiple(op, std::move(packedResultVals));
628class SparseLvlOpConverter :
public OpConversionPattern<LvlOp> {
630 using OpConversionPattern::OpConversionPattern;
632 matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
633 ConversionPatternRewriter &rewriter)
const override {
634 std::optional<int64_t> lvl = op.getConstantLvlIndex();
635 RankedTensorType srcType = op.getSource().getType();
640 auto sz = desc.
getLvlSize(rewriter, op.getLoc(), *lvl);
642 rewriter.replaceOp(op, sz);
648struct SparseReorderCOOConverter :
public OpConversionPattern<ReorderCOOOp> {
649 using OpConversionPattern::OpConversionPattern;
651 matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
652 ConversionPatternRewriter &rewriter)
const override {
653 Location loc = op.getLoc();
666 op.getInputCoo().getType());
677 SortOp::create(rewriter, loc, nnz, crd,
ValueRange{val}, id,
678 rewriter.getIndexAttr(0), op.getAlgorithm());
682 rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()});
687template <
typename Op, StorageSpecifierKind kind>
688class SparseSliceGetterOpConverter :
public OpConversionPattern<Op> {
690 using OpConversionPattern<
Op>::OpConversionPattern;
691 using typename OpConversionPattern<Op>::OneToNOpAdaptor;
694 matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
695 ConversionPatternRewriter &rewriter)
const override {
698 op.getSlice().getType());
700 op.getDim().getZExtValue());
702 rewriter.replaceOp(op, v);
708class SparseCastConverter :
public OpConversionPattern<tensor::CastOp> {
710 using OpConversionPattern::OpConversionPattern;
712 matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
713 ConversionPatternRewriter &rewriter)
const override {
717 if (!encDst || encDst != encSrc)
719 rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
724class SparseReMapConverter :
public OpConversionPattern<ReinterpretMapOp> {
726 using OpConversionPattern::OpConversionPattern;
728 matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
729 ConversionPatternRewriter &rewriter)
const override {
731 rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
737class SparseTensorAllocConverter
738 :
public OpConversionPattern<bufferization::AllocTensorOp> {
740 using OpConversionPattern::OpConversionPattern;
741 SparseTensorAllocConverter(
const TypeConverter &typeConverter,
742 MLIRContext *context,
bool enableInit)
743 : OpConversionPattern(typeConverter, context),
744 enableBufferInitialization(enableInit) {}
747 matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
748 ConversionPatternRewriter &rewriter)
const override {
750 if (!resType.hasEncoding())
753 Location loc = op.getLoc();
757 adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
758 SmallVector<Value> fields;
762 auto memrefTp = cast<MemRefType>(field.getType());
763 auto size = memref::DimOp::create(rewriter, loc, field, 0);
765 memref::AllocOp::create(rewriter, loc, memrefTp,
ValueRange{size});
766 memref::CopyOp::create(rewriter, loc, field, copied);
767 fields.push_back(copied);
772 rewriter.replaceOpWithMultiple(op, {fields});
776 if (!resType.isIdentity()) {
777 return rewriter.notifyMatchFailure(
778 op,
"try run --sparse-reinterpret-map before codegen");
781 SmallVector<Value> lvlSizesValues;
787 Value sizeHint = op.getSizeHint();
788 SmallVector<Value> fields;
790 sizeHint, lvlSizesValues, fields);
793 rewriter.replaceOpWithMultiple(op, {fields});
798 bool enableBufferInitialization;
802class SparseTensorEmptyConverter :
public OpConversionPattern<tensor::EmptyOp> {
804 using OpConversionPattern::OpConversionPattern;
805 SparseTensorEmptyConverter(
const TypeConverter &typeConverter,
806 MLIRContext *context,
bool enableInit)
807 : OpConversionPattern(typeConverter, context),
808 enableBufferInitialization(enableInit) {}
811 matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
812 ConversionPatternRewriter &rewriter)
const override {
814 if (!resType.hasEncoding())
817 if (!resType.isIdentity()) {
818 return rewriter.notifyMatchFailure(
819 op,
"try run --sparse-reinterpret-map before codegen");
822 Location loc = op.getLoc();
824 SmallVector<Value> lvlSizesValues;
829 SmallVector<Value> fields;
831 sizeHint, lvlSizesValues, fields);
834 rewriter.replaceOpWithMultiple(op, {fields});
839 bool enableBufferInitialization;
843class SparseTensorDeallocConverter
844 :
public OpConversionPattern<bufferization::DeallocTensorOp> {
846 using OpConversionPattern::OpConversionPattern;
847 SparseTensorDeallocConverter(
const TypeConverter &typeConverter,
848 MLIRContext *context,
bool createDeallocs)
849 : OpConversionPattern(typeConverter, context),
850 createDeallocs(createDeallocs) {}
853 matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
854 ConversionPatternRewriter &rewriter)
const override {
861 if (createDeallocs) {
863 Location loc = op.getLoc();
866 cast<RankedTensorType>(op.getTensor().getType()));
869 memref::DeallocOp::create(rewriter, loc, input);
871 rewriter.eraseOp(op);
876 const bool createDeallocs;
880class SparseTensorLoadConverter :
public OpConversionPattern<LoadOp> {
882 using OpConversionPattern::OpConversionPattern;
884 matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
885 ConversionPatternRewriter &rewriter)
const override {
888 op.getTensor().getType());
890 if (op.getHasInserts())
893 rewriter.replaceOpWithMultiple(op, {desc.
getFields()});
899class SparseExpandConverter :
public OpConversionPattern<ExpandOp> {
901 using OpConversionPattern::OpConversionPattern;
903 matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
904 ConversionPatternRewriter &rewriter)
const override {
907 Location loc = op->getLoc();
909 op.getTensor().getType());
911 Type eltType = srcType.getElementType();
912 Type boolType = rewriter.getIntegerType(1);
913 Type idxType = rewriter.getIndexType();
915 rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
919 const auto sz = desc.
getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
921 const auto genAlloc = [&](Type t) {
922 const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
923 return memref::AllocOp::create(rewriter, loc, memTp,
ValueRange{sz});
928 Value values = genAlloc(eltType);
929 Value filled = genAlloc(boolType);
930 Value added = genAlloc(idxType);
937 linalg::FillOp::create(rewriter, loc,
940 linalg::FillOp::create(rewriter, loc,
944 assert(op.getNumResults() == 4);
945 rewriter.replaceOp(op, {values, filled, added, zero});
951class SparseCompressConverter :
public OpConversionPattern<CompressOp> {
953 using OpConversionPattern::OpConversionPattern;
955 matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
956 ConversionPatternRewriter &rewriter)
const override {
957 Location loc = op->getLoc();
958 SmallVector<Value> fields;
960 op.getTensor().getType());
961 Value values = llvm::getSingleElement(adaptor.getValues());
962 Value filled = llvm::getSingleElement(adaptor.getFilled());
963 Value added = llvm::getSingleElement(adaptor.getAdded());
964 Value count = llvm::getSingleElement(adaptor.getCount());
966 Type eltType = dstType.getElementType();
970 if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
971 SortOp::create(rewriter, loc, count, added,
ValueRange{},
972 rewriter.getMultiDimIdentityMap(1),
973 rewriter.getIndexAttr(0),
974 SparseTensorSortKind::HybridQuickSort);
991 Value i = loop.getInductionVar();
993 Value crd =
genLoad(rewriter, loc, added, i);
994 Value value =
genLoad(rewriter, loc, values, crd);
996 SmallVector<Type> flatSpTensorTps = llvm::map_to_vector(
997 desc.
getFields(), [](Value v) { return v.getType(); });
998 SmallVector<Value> flatLvlCoords =
flattenValues(adaptor.getLvlCoords());
999 params.append(flatLvlCoords.begin(), flatLvlCoords.end());
1000 params.push_back(crd);
1001 params.push_back(value);
1002 SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
1004 SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
1007 scf::YieldOp::create(rewriter, loc, insertRet);
1009 rewriter.setInsertionPointAfter(loop);
1011 Operation *parent =
getTop(op);
1012 rewriter.setInsertionPointAfter(parent);
1013 memref::DeallocOp::create(rewriter, loc, values);
1014 memref::DeallocOp::create(rewriter, loc, filled);
1015 memref::DeallocOp::create(rewriter, loc, added);
1017 rewriter.replaceOpWithMultiple(op, {loop->getResults()});
1023class SparseInsertConverter :
public OpConversionPattern<tensor::InsertOp> {
1025 using OpConversionPattern::OpConversionPattern;
1027 matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
1028 ConversionPatternRewriter &rewriter)
const override {
1030 if (!stt.hasEncoding())
1032 assert(stt.isIdentity() &&
"Run reinterpret-map before conversion.");
1034 Location loc = op.getLoc();
1038 SmallVector<Value> params = llvm::to_vector(desc.
getFields());
1039 SmallVector<Value> flatIndices =
flattenValues(adaptor.getIndices());
1040 params.append(flatIndices.begin(), flatIndices.end());
1041 params.push_back(llvm::getSingleElement(adaptor.getScalar()));
1042 SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
1044 SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
1046 rewriter.replaceOpWithMultiple(op, {ret});
1052class SparseToPositionsConverter :
public OpConversionPattern<ToPositionsOp> {
1054 using OpAdaptor = ToPositionsOp::Adaptor;
1055 using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
1057 matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
1058 ConversionPatternRewriter &rewriter)
const override {
1062 Location loc = op.getLoc();
1063 Level lvl = op.getLevel();
1065 op.getTensor().getType());
1068 rewriter.replaceOp(op,
genSliceToSize(rewriter, loc, mem, size));
1074class SparseToCoordinatesConverter
1075 :
public OpConversionPattern<ToCoordinatesOp> {
1077 using OpAdaptor = ToCoordinatesOp::Adaptor;
1078 using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
1080 matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
1081 ConversionPatternRewriter &rewriter)
const override {
1085 Location loc = op.getLoc();
1086 Level lvl = op.getLevel();
1088 op.getTensor().getType());
1089 auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
1094 rewriter.replaceOp(op, mem);
1100class SparseToCoordinatesBufferConverter
1101 :
public OpConversionPattern<ToCoordinatesBufferOp> {
1103 using OpAdaptor = ToCoordinatesBufferOp::Adaptor;
1104 using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
1106 matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
1107 ConversionPatternRewriter &rewriter)
const override {
1111 Location loc = op.getLoc();
1114 op.getTensor().getType());
1117 rewriter.replaceOp(op,
genSliceToSize(rewriter, loc, mem, size));
1123class SparseToValuesConverter :
public OpConversionPattern<ToValuesOp> {
1125 using OpAdaptor = ToValuesOp::Adaptor;
1126 using OpConversionPattern<ToValuesOp>::OpConversionPattern;
1128 matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
1129 ConversionPatternRewriter &rewriter)
const override {
1133 Location loc = op.getLoc();
1135 op.getTensor().getType());
1138 rewriter.replaceOp(op,
genSliceToSize(rewriter, loc, mem, size));
1144class SparseConvertConverter :
public OpConversionPattern<ConvertOp> {
1146 using OpConversionPattern::OpConversionPattern;
1148 matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
1149 ConversionPatternRewriter &rewriter)
const override {
1151 SparseTensorEncodingAttr encSrc =
1157 if (!encSrc || !encDst)
1162 assert(!encDst.isSlice() &&
"Cannot convert to a sparse tensor slices.");
1166 if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
1171 Type retElemTp = op.getResult().getType().getElementType();
1172 Type srcElemTp = op.getSource().getType().getElementType();
1174 if (retElemTp == srcElemTp && encDst == encSrc) {
1175 rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
1187 Location loc = op.getLoc();
1189 op.getSource().getType());
1190 SmallVector<Value> fields;
1192 SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
1193 [&rewriter, &fields, srcDesc,
1195 LevelType ) ->
bool {
1197 if (fKind == SparseTensorFieldKind::StorageSpec) {
1198 fields.push_back(srcDesc.getSpecifier());
1201 Value srcMem = srcDesc.getMemRefField(fIdx);
1205 Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
1206 auto dstMem = memref::AllocOp::create(rewriter, loc,
1207 cast<MemRefType>(fTp), sz);
1208 if (fTp != srcMem.getType()) {
1211 rewriter, loc, constantIndex(rewriter, loc, 0), sz,
1212 constantIndex(rewriter, loc, 1),
1213 [srcMem, &dstMem](OpBuilder &builder, Location loc,
1215 Value v = memref::LoadOp::create(builder, loc, srcMem, ivs);
1216 Value casted = genCast(builder, loc, v,
1217 dstMem.getType().getElementType());
1218 memref::StoreOp::create(builder, loc, casted, dstMem, ivs);
1224 memref::CopyOp::create(rewriter, loc, srcMem, dstMem);
1226 fields.push_back(dstMem);
1231 rewriter.replaceOpWithMultiple(op, {fields});
1236class SparseExtractSliceConverter
1237 :
public OpConversionPattern<tensor::ExtractSliceOp> {
1239 using OpConversionPattern::OpConversionPattern;
1241 matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
1242 ConversionPatternRewriter &rewriter)
const override {
1243 Location loc = op.getLoc();
1244 MLIRContext *ctx = op.getContext();
1248 if (!srcEnc || !dstEnc || !dstEnc.isSlice())
1250 assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
1252 SmallVector<Value> fields;
1254 op.getSource().getType());
1256 auto newSpec = StorageSpecifierInitOp::create(
1257 rewriter, loc, StorageSpecifierType::get(ctx, dstEnc),
1262 for (
auto [idx, offset, size, stride] : llvm::enumerate(
1263 op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
1278 assert(srcEnc.isIdentity());
1288 rewriter.replaceOpWithMultiple(op, {desc.
getFields()});
1294class SparseNumberOfEntriesConverter
1295 :
public OpConversionPattern<NumberOfEntriesOp> {
1297 using OpConversionPattern::OpConversionPattern;
1299 matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
1300 ConversionPatternRewriter &rewriter)
const override {
1305 op.getTensor().getType());
1306 rewriter.replaceOp(op, desc.
getValMemSize(rewriter, op.getLoc()));
1311struct SparseAssembleOpConverter :
public OpConversionPattern<AssembleOp> {
1312 using OpConversionPattern::OpConversionPattern;
1314 matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
1315 ConversionPatternRewriter &rewriter)
const override {
1316 Location loc = op.getLoc();
1319 SmallVector<Value> fields;
1323 [&rewriter, &fields, &op, &stt,
1325 Level , LevelType lt) ->
bool {
1326 assert(fields.size() == fIdx);
1327 if (fKind == SparseTensorFieldKind::StorageSpec) {
1332 Value tensor = fKind == SparseTensorFieldKind::ValMemRef
1334 : op.getLevels()[fIdx];
1337 if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
1340 mem.getType(), stt.getBatchLvlRank());
1341 mem = memref::CastOp::create(
1342 rewriter, loc, fType,
1343 memref::CollapseShapeOp::create(rewriter, loc, mem, reassoc));
1345 mem = memref::CastOp::create(rewriter, loc, fType, mem);
1347 fields.push_back(mem);
1352 MutSparseTensorDescriptor desc(stt, fields);
1359 Level trailCOOStart = stt.getAoSCOOStart();
1360 Level trailCOORank = stt.getLvlRank() - trailCOOStart;
1362 for (
Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
1363 assert(ShapedType::isStatic(stt.getDimShape()[lvl]));
1366 auto lvlSize =
constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
1367 desc.
setLvlSize(rewriter, loc, lvl, lvlSize);
1370 if (lvl > trailCOOStart)
1374 LevelType lt = stt.getLvlType(lvl);
1376 if (lt.
isa<LevelFormat::Dense>()) {
1377 memSize = arith::MulIOp::create(rewriter, loc, lvlSize, memSize);
1378 posBack = arith::SubIOp::create(rewriter, loc, memSize, c1);
1381 if (lt.
isa<LevelFormat::Batch>()) {
1391 memSize = arith::MulIOp::create(rewriter, loc, memSize, c2);
1392 posBack = arith::SubIOp::create(rewriter, loc, memSize, c1);
1396 memSize = arith::AddIOp::create(rewriter, loc, memSize, c1);
1402 SmallVector<Value> batched(stt.getBatchLvlRank(),
1404 batched.push_back(posBack);
1406 posBack = arith::SubIOp::create(rewriter, loc, posBack, c1);
1410 if (lvl == trailCOOStart) {
1411 Value cooSz = arith::MulIOp::create(
1412 rewriter, loc, memSize,
constantIndex(rewriter, loc, trailCOORank));
1420 rewriter.replaceOpWithMultiple(op, {desc.
getFields()});
1425struct SparseDisassembleOpConverter
1426 :
public OpConversionPattern<DisassembleOp> {
1427 using OpConversionPattern::OpConversionPattern;
1428 SparseDisassembleOpConverter(
const TypeConverter &typeConverter,
1429 MLIRContext *context)
1430 : OpConversionPattern(typeConverter, context) {}
1433 matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
1434 ConversionPatternRewriter &rewriter)
const override {
1436 op.getTensor().getType());
1437 Location loc = op.getLoc();
1438 SmallVector<Value> retMem;
1439 SmallVector<Value> retLen;
1443 Level lvl, LevelType lt) ->
bool {
1444 if (fKind == SparseTensorFieldKind::StorageSpec)
1449 if (fKind == SparseTensorFieldKind::ValMemRef) {
1452 dst =
genToMemref(rewriter, loc, op.getOutValues());
1454 retMem.push_back(dst);
1455 Type valLenTp = op.getValLen().getType();
1458 assert(fKind == SparseTensorFieldKind::PosMemRef ||
1459 fKind == SparseTensorFieldKind::CrdMemRef);
1461 sz = fKind == SparseTensorFieldKind::PosMemRef
1465 dst =
genToMemref(rewriter, loc, op.getOutLevels()[fid]);
1466 retMem.push_back(dst);
1468 Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
1471 Value flatOut = dst;
1472 if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
1475 flatOut = memref::CollapseShapeOp::create(rewriter, loc, dst, reassoc);
1479 memref::CopyOp::create(rewriter, loc, srcMem, dstMem);
1484 SmallVector<Value> retValues =
1485 llvm::map_to_vector(retMem, [&rewriter, loc](Value v) -> Value {
1486 return bufferization::ToTensorOp::create(
1491 retValues.append(retLen.begin(), retLen.end());
1492 rewriter.replaceOp(op, retValues);
1497struct SparseNewConverter :
public OpConversionPattern<NewOp> {
1498 using OpConversionPattern::OpConversionPattern;
1500 matchAndRewrite(
NewOp op, OpAdaptor adaptor,
1501 ConversionPatternRewriter &rewriter)
const override {
1502 Location loc = op.getLoc();
1506 if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
1520 SmallVector<Value> dimSizesValues;
1521 Value dimSizesBuffer;
1522 Value reader =
genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
1523 dimSizesValues, dimSizesBuffer);
1526 const Type indexTp = rewriter.getIndexType();
1527 Value nse =
createFuncCall(rewriter, loc,
"getSparseTensorReaderNSE",
1528 {indexTp}, {reader}, EmitCInterface::Off)
1532 SmallVector<Value> lvlSizesValues;
1533 Value dim2lvlBuffer;
1534 Value lvl2dimBuffer;
1535 genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
1536 lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
1539 Value sizeHint = nse;
1540 SmallVector<Value> fields;
1542 lvlSizesValues, fields);
1545 MutSparseTensorDescriptor desc(dstTp, fields);
1548 const Type boolTp = rewriter.getIntegerType(1);
1549 const Type elemTp = dstTp.getElementType();
1550 const Type crdTp = dstTp.getCrdType();
1551 SmallString<32> readToBuffersFuncName{
"getSparseTensorReaderReadToBuffers",
1556 {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
1562 const Level lvlRank = dstTp.getLvlRank();
1563 if (dstTp.isOrderedLvl(lvlRank - 1)) {
1564 Value kFalse =
constantI1(rewriter, loc,
false);
1565 Value notSorted = arith::CmpIOp::create(
1566 rewriter, loc, arith::CmpIPredicate::eq, isSorted, kFalse);
1568 scf::IfOp::create(rewriter, loc, notSorted,
false);
1569 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1570 auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
1571 SortOp::create(rewriter, loc, nse, xs,
ValueRange{ys}, xPerm,
1572 rewriter.getIndexAttr(0),
1573 SparseTensorSortKind::HybridQuickSort);
1574 rewriter.setInsertionPointAfter(ifOp);
1580 const Type posTp = dstTp.getPosType();
1581 const Value posNse =
genCast(rewriter, loc, nse, posTp);
1582 memref::StoreOp::create(rewriter, loc, posNse, posMemref0, c1);
1585 Value coordinatesSize = arith::MulIOp::create(
1593 createFuncCall(rewriter, loc,
"delSparseTensorReader", {}, {reader},
1594 EmitCInterface::Off);
1597 rewriter.replaceOpWithMultiple(op, {fields});
1602struct SparseHasRuntimeLibraryConverter
1603 :
public OpConversionPattern<HasRuntimeLibraryOp> {
1604 using OpConversionPattern::OpConversionPattern;
1606 matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
1607 ConversionPatternRewriter &rewriter)
const override {
1608 auto i1Type = rewriter.getI1Type();
1609 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
1610 op, i1Type, rewriter.getIntegerAttr(i1Type, 0));
1625 bool createSparseDeallocs,
bool enableBufferInitialization) {
1627 SparseAssembleOpConverter, SparseDisassembleOpConverter,
1628 SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
1629 SparseCastConverter, SparseExtractSliceConverter,
1630 SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter,
1631 SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter,
1632 SparseSliceGetterOpConverter<ToSliceOffsetOp,
1633 StorageSpecifierKind::DimOffset>,
1634 SparseSliceGetterOpConverter<ToSliceStrideOp,
1635 StorageSpecifierKind::DimStride>,
1636 SparseToPositionsConverter, SparseToCoordinatesConverter,
1637 SparseToCoordinatesBufferConverter, SparseToValuesConverter,
1638 SparseConvertConverter, SparseNewConverter,
1639 SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
1641 patterns.
add<SparseTensorDeallocConverter>(
1642 typeConverter, patterns.
getContext(), createSparseDeallocs);
1643 patterns.
add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
1644 typeConverter, patterns.
getContext(), enableBufferInitialization);
memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values))
static void createAllocFields(OpBuilder &builder, Location loc, SparseTensorType stt, bool enableInit, Value sizeHint, SmallVectorImpl< Value > &lvlSizesValues, SmallVectorImpl< Value > &fields)
Creates an allocation for each field in the sparse tensor type.
static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper, MutableArrayRef< Value > fields, Value lower=Value())
Creates a straightforward counting for-loop.
static void genEndInsert(OpBuilder &builder, Location loc, SparseTensorDescriptor desc)
Generates insertion finalization code.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static void allocSchemeForRank(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, Level startLvl)
Generates code that allocates a sparse storage scheme for given rank.
static SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten the given value ranges into a single vector of values.
static Value genCompressed(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, ValueRange lvlCoords, Value, Value parentPos, Level lvl)
Helper method that generates a block specific to the compressed case:
static Value createAllocation(OpBuilder &builder, Location loc, MemRefType memRefType, Value sz, bool enableInit)
Creates allocation operation.
static void createPushback(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, SparseTensorFieldKind kind, std::optional< Level > lvl, Value value, Value repeat=Value())
Creates a push back operation.
static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, Value sz)
Generates a subview into the sizes.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
static SmallVector< ReassociationIndices > getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls)
Creates the reassociation array.
static void createDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt, ValueRange dynSizes, SmallVectorImpl< Value > &dimSizesValues)
Creates the dim sizes array, filling in from dynamic sizes.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext * getContext() const
Return the context this location is uniqued in.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A helper class to simplify lowering operations with/without function calls.
Using SmallVector for mutable descriptor allows users to reuse it as a tmp buffer to append value fo...
void setMemRefField(SparseTensorFieldKind kind, std::optional< Level > lvl, Value v)
Adds additional setters for mutable descriptor, update the value for required field.
void setSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl, Value v)
void setSpecifier(Value newSpec)
void setPosMemSize(OpBuilder &builder, Location loc, Level lvl, Value v)
void setValMemSize(OpBuilder &builder, Location loc, Value v)
void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value v)
void setField(FieldIndex fidx, Value v)
void setCrdMemSize(OpBuilder &builder, Location loc, Level lvl, Value v)
Value getSpecifier() const
Getters: get the value for required field.
RankedTensorType getRankedTensorType() const
ValueArrayRef getFields() const
Value getValMemSize(OpBuilder &builder, Location loc) const
Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl) const
std::pair< FieldIndex, unsigned > getCrdMemRefIndexAndStride(Level lvl) const
StorageLayout getLayout() const
ValueRange getMemRefFields() const
Value getAOSMemRef() const
Type getMemRefElementType(SparseTensorFieldKind kind, std::optional< Level > lvl) const
unsigned getNumFields() const
Value getCrdMemSize(OpBuilder &builder, Location loc, Level lvl) const
Value getMemRefField(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Value getPosMemSize(OpBuilder &builder, Location loc, Level lvl) const
Value getValMemRef() const
Value getField(FieldIndex fidx) const
Value getLvlSize(OpBuilder &builder, Location loc, Level lvl) const
Value getPosMemRef(Level lvl) const
Uses ValueRange for immutable descriptors.
static Value getInitValue(OpBuilder &builder, Location loc, SparseTensorType stt)
A wrapper around RankedTensorType, which has three goals:
Type getElementType() const
bool isAllOrdered() const
Returns true for tensors where every level is ordered.
bool isCOOType(Level startLvl=0, bool isUnique=true) const
Returns true iff this sparse tensor type has a trailing COO region starting at the given level.
Dimension getDimRank() const
Returns the dimension-rank.
bool isAllDense() const
Returns true for tensors where every level is dense.
bool hasSameDimToLvl(const SparseTensorType &other) const
Returns true iff the two types have the same mapping.
ArrayRef< Size > getDimShape() const
Returns the dimension-shape.
bool isCompressedLvl(Level l) const
Level getLvlRank() const
Returns the level-rank.
bool isDenseLvl(Level l) const
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
LevelType getLvlType(Level l) const
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
bool isUniqueLvl(Level l) const
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with t...
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
SparseTensorDescriptor getDescriptorFromTensorTuple(ValueRange adaptorValues, RankedTensorType type)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool isWithCrdLT(LevelType lt)
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
bool isWithPosLT(LevelType lt)
std::string toMLIRString(LevelType lt)
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
bool isSingletonLT(LevelType lt)
bool isCompressedLT(LevelType lt)
TypedValue< BaseMemRefType > genToMemref(OpBuilder &builder, Location loc, Value tensor)
bool isLooseCompressedLT(LevelType lt)
unsigned FieldIndex
The type of field indices.
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, ValueRange s)
Generates a pointer/index load from the sparse storage scheme.
StringRef overheadTypeFunctionSuffix(OverheadType ot)
Convert OverheadType to its function-name suffix.
uint64_t Level
The type of level identifiers and level-ranks.
Operation * getTop(Operation *op)
Scans to top of generated loop.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Value genMapBuffers(OpBuilder &builder, Location loc, SparseTensorType stt, ArrayRef< Value > dimSizesValues, Value dimSizesBuffer, SmallVectorImpl< Value > &lvlSizesValues, Value &dim2lvlBuffer, Value &lvl2dimBuffer)
Generates code to set up the buffer parameters for a map.
Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt, Value tensor, SmallVectorImpl< Value > &dimSizesValues, Value &dimSizesBuffer)
Generates code that opens a reader and sets the dimension sizes.
Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem, Type dstTp)
Add conversion from scalar to given type (possibly a 0-rank tensor).
bool isDenseLT(LevelType lt)
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SparseTensorFieldKind
===----------------------------------------------------------------------===// The sparse tensor storag...
func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Creates a CallOp to the function reference returned by getFunc() in the builder's module.
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy)
Add type casting between arith and index types when needed.
StringRef primaryTypeFunctionSuffix(PrimaryType pt)
Convert PrimaryType to its function-name suffix.
MutSparseTensorDescriptor getMutDescriptorFromTensorTuple(ValueRange adaptorValues, SmallVectorImpl< Value > &fields, RankedTensorType type)
StorageSpecifierKind toSpecifierKind(SparseTensorFieldKind kind)
bool isNOutOfMLT(LevelType lt)
Include the generated interface declarations.
void populateSparseTensorCodegenPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns, bool createSparseDeallocs, bool enableBufferInitialization)
Sets up sparse tensor codegen rules.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.