for (const auto &vals : values)
  llvm::append_range(result, vals);
return memref::LoadOp::create(builder, loc, mem, idx);
val = genCast(builder, loc, val,
              cast<ShapedType>(mem.getType()).getElementType());
memref::StoreOp::create(builder, loc, val, mem, idx);
scf::ForOp forOp = scf::ForOp::create(builder, loc, lower, upper, one, fields);
for (unsigned i = 0, e = fields.size(); i < e; i++)
  fields[i] = forOp.getRegionIterArg(i);
auto pushBackOp = PushBackOp::create(
    builder, loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl),
    field, genCast(builder, loc, value, etp), repeat);
desc.setSpecifierField(builder, loc, specFieldKind, lvl,
                       pushBackOp.getNewSize());
for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
  // Loose compression appends both lo/hi positions, hence the doubling.
  linear = arith::MulIOp::create(builder, loc, linear, two);
  // Dense level: keep compounding the size.
  linear = arith::MulIOp::create(builder, loc, linear, size);
}
// Reached the values array: prepare for an insertion.
createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
               std::nullopt, valZero, linear);
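To make the storage scheme concrete, here is a plain-C++ analogue (hypothetical types and names, not the MLIR API) of what allocSchemeForRank prepares: dense levels fold their sizes into a running linear factor, a compressed level seeds its positions buffer, and an all-dense suffix pre-allocates linear zero values.

#include <cstddef>
#include <cstdint>
#include <vector>

struct SparseStorage {
  std::vector<std::vector<uint64_t>> positions;   // one per level
  std::vector<std::vector<uint64_t>> coordinates; // one per level
  std::vector<double> values;
};

// Conceptual analogue of allocSchemeForRank; assumes `s.positions` already
// has one entry per level.
void allocScheme(SparseStorage &s, const std::vector<uint64_t> &lvlSizes,
                 const std::vector<bool> &isCompressed, size_t startLvl) {
  uint64_t linear = 1;
  for (size_t lvl = startLvl; lvl < lvlSizes.size(); ++lvl) {
    if (isCompressed[lvl]) {
      // Seed positions with `linear` zero entries; deeper levels are
      // initialized lazily once insertions reach them.
      s.positions[lvl].assign(linear, 0);
      return;
    }
    linear *= lvlSizes[lvl]; // dense level: compound the size
  }
  s.values.assign(linear, 0.0); // all-dense suffix: pre-allocate values
}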
static Value createAllocation(OpBuilder &builder, Location loc,
                              MemRefType memRefType, Value sz,
                              bool enableInit) {
  Value buffer = memref::AllocOp::create(builder, loc, memRefType, sz);
  Type elemType = memRefType.getElementType();
  if (enableInit)
    linalg::FillOp::create(builder, loc, constantZero(builder, loc, elemType),
                           buffer);
  return buffer;
}
dimSizesValues.clear();
dimSizesValues.reserve(dimRank);
unsigned i = 0;
for (const Size sz : stt.getDimShape())
  dimSizesValues.push_back(ShapedType::isDynamic(sz)
                               ? dynSizes[i++]
                               : constantIndex(builder, loc, sz));
Value posHeuristic, crdHeuristic, valHeuristic;
if (stt.isAllDense()) {
  valHeuristic = lvlSizesValues[0];
  for (Level lvl = 1; lvl < lvlRank; lvl++)
    valHeuristic = arith::MulIOp::create(builder, loc, valHeuristic,
                                         lvlSizesValues[lvl]);
} else if (sizeHint) {
  if (stt.getAoSCOOStart() == 0) {
    posHeuristic = constantIndex(builder, loc, 2);
    crdHeuristic = arith::MulIOp::create(
        builder, loc, constantIndex(builder, loc, lvlRank), sizeHint);
  } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
    posHeuristic = arith::AddIOp::create(builder, loc, sizeHint,
                                         constantIndex(builder, loc, 1));
    crdHeuristic = sizeHint;
  } else {
    posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
  }
  valHeuristic = sizeHint;
} else {
  posHeuristic = crdHeuristic = valHeuristic =
      constantIndex(builder, loc, 16);
}
[&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
 enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
             Level /*lvl*/, LevelType /*lt*/) -> bool {
  assert(fields.size() == fIdx);
  Value field;
  // Allocate the buffer with the capacity heuristic matching the field kind:
  field = createAllocation(builder, loc, cast<MemRefType>(fType),
                           posHeuristic, enableInit); // PosMemRef
  field = createAllocation(builder, loc, cast<MemRefType>(fType),
                           crdHeuristic, enableInit); // CrdMemRef
  field = createAllocation(builder, loc, cast<MemRefType>(fType),
                           valHeuristic, enableInit); // ValMemRef
  fields.push_back(field);
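The same sizing policy can be written down as ordinary C++ for reference; a sketch under the assumptions spelled out in the comments, with hypothetical names:

#include <cstdint>

struct Capacities {
  uint64_t pos, crd, val;
};

// Mirrors the heuristics above: an all-dense tensor only needs a value
// buffer sized to the product of level sizes; with a size hint, buffers
// start at hint-derived capacities; otherwise a small default of 16.
Capacities pickCapacities(bool allDense, uint64_t denseSize, uint64_t sizeHint,
                          uint64_t lvlRank, bool aosCooAtLvl0) {
  if (allDense)
    return {0, 0, denseSize};
  if (sizeHint != 0) {
    if (aosCooAtLvl0)
      return {2, lvlRank * sizeHint, sizeHint}; // AoS COO storage
    return {sizeHint + 1, sizeHint, sizeHint};  // e.g. the CSR-like case
  }
  return {16, 16, 16};
}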
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
  desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
assert(lvl < lvlRank && "Level is out of bounds");
assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
       "Level-rank mismatch");
const Value pp1 = arith::AddIOp::create(builder, loc, parentPos, one);
const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
const auto [crdFidx, crdStride] = desc.getCrdMemRefIndexAndStride(lvl);
const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl);
const Value crdStrideC =
    crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
const Value msz =
    crdStrideC ? arith::DivUIOp::create(builder, loc, crdMsz, crdStrideC)
               : crdMsz;
const Value plast = arith::SubIOp::create(
    builder, loc, genCast(builder, loc, pstop, indexType), one);
Value lt = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult,
                                 pstart, pstop);
types.push_back(boolType);
scf::IfOp ifOp1 = scf::IfOp::create(builder, loc, types, lt, true);
Value crd = genLoad(
    builder, loc, desc.getMemRefField(crdFidx),
    crdStrideC ? arith::MulIOp::create(builder, loc, plast, crdStrideC)
               : plast);
Value eq = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
                                 genCast(builder, loc, crd, indexType),
                                 lvlCoords[lvl]);
scf::YieldOp::create(builder, loc, eq);
genStore(builder, loc, msz, positionsAtLvl, parentPos);
scf::YieldOp::create(builder, loc, constantI1(builder, loc, false));
for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
  types.push_back(desc.getField(i).getType());
types.push_back(indexType);
scf::IfOp ifOp2 = scf::IfOp::create(builder, loc, types, p, true);
scf::YieldOp::create(builder, loc, desc.getFields());
Value mszp1 = arith::AddIOp::create(builder, loc, msz, one);
genStore(builder, loc, mszp1, positionsAtLvl, pp1);
if ((lvl + 1) < lvlRank)
  allocSchemeForRank(builder, loc, desc, lvl + 1);
scf::YieldOp::create(builder, loc, desc.getFields());
unsigned o = 0;
for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
  desc.setField(i, ifOp2.getResult(o++));
return ifOp2.getResult(o);
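Stripped of the IR-building, the control flow that genCompressed emits for one compressed level looks like the following plain-C++ sketch (simplified, and only a conceptual analogue: it assumes entries arrive in lexicographic order, which is what makes the last-element membership test sufficient):

#include <cstdint>
#include <vector>

// Returns the position at this level that the next level uses as its parent.
uint64_t insertCompressed(std::vector<uint64_t> &positions,
                          std::vector<uint64_t> &coordinates,
                          uint64_t parentPos, uint64_t coord) {
  uint64_t pstart = positions[parentPos];
  uint64_t pstop = positions[parentPos + 1];
  // During ordered insertion the segment of `parentPos` is the tail of
  // `coordinates`, so only its last element needs to be compared.
  bool present = pstart < pstop && coordinates[pstop - 1] == coord;
  if (!present) {
    if (pstart == pstop) // first entry for this parent:
      positions[parentPos] = coordinates.size(); // segment starts at the end
    coordinates.push_back(coord);
    positions[parentPos + 1] = coordinates.size();
  }
  return positions[parentPos + 1] - 1;
}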
for (Level lvl = 0; lvl < lvlRank; lvl++) {
scf::ForOp loop = createFor(builder, loc, hi, inits, one);
Value i = loop.getInductionVar();
Value oldv = loop.getRegionIterArg(0);
Value cond = arith::CmpIOp::create(
    builder, loc, arith::CmpIPredicate::eq, newv, posZero);
scf::IfOp ifOp = scf::IfOp::create(builder, loc, TypeRange(posType), cond,
                                   true);
genStore(builder, loc, oldv, posMemRef, i);
scf::YieldOp::create(builder, loc, oldv);
scf::YieldOp::create(builder, loc, newv);
scf::YieldOp::create(builder, loc, ifOp.getResult(0));
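The finalization loop above can be read as the following scalar code (a conceptual analogue, not the MLIR implementation): position entries left at zero belong to parents that never received an entry and inherit the previous segment end.

#include <cstddef>
#include <cstdint>
#include <vector>

void finalizePositions(std::vector<uint64_t> &positions) {
  if (positions.empty())
    return;
  uint64_t prev = positions[0];
  for (size_t i = 1; i < positions.size(); ++i) {
    if (positions[i] == 0)
      positions[i] = prev; // empty segment: back-fill with previous end
    else
      prev = positions[i];
  }
}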
auto memTp = llvm::cast<MemRefType>(mem.getType());
if (memTp.getRank() > 1)
  return mem;
// Truncate linear memrefs to the given size.
return memref::SubViewOp::create(
           builder, loc,
           MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
           mem, ValueRange{}, ValueRange{sz}, ValueRange{},
           ArrayRef<int64_t>{0}, ArrayRef<int64_t>{ShapedType::kDynamic},
           ArrayRef<int64_t>{1})
    .getResult();
for (unsigned i = 0; i < batchLvls; i++)
  ret[i].push_back(i);
for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
  ret.back().push_back(i);
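In plain C++ the reassociation computed here is simply (hypothetical standalone helper):

#include <cstdint>
#include <vector>

// {0}, {1}, ..., {batchLvls-1}, {batchLvls, ..., rank-1}
std::vector<std::vector<int64_t>> reassociation(int64_t rank,
                                                int64_t batchLvls) {
  std::vector<std::vector<int64_t>> ret(batchLvls + 1);
  for (int64_t i = 0; i < batchLvls; ++i)
    ret[i].push_back(i); // each batch dimension keeps its own group
  for (int64_t i = batchLvls; i < rank; ++i)
    ret.back().push_back(i); // remaining dims collapse into one group
  return ret;
}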
class SparseInsertGenerator
    : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
  SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
                        bool genCall)
      : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp) {};
SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
                                     OpBuilder &builder, Location loc) {
  const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
  const Level lvlRank = stt.getLvlRank();
  // Extract fields and coordinates from args.
  SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
  MutSparseTensorDescriptor desc(stt, fields);
  const SmallVector<Value> coords =
      llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
  Value value = args.back();
  Value parentPos = constantZero(builder, loc, builder.getIndexType());
for (Level lvl = 0; lvl < lvlRank; lvl++) {
  const auto lt = stt.getLvlType(lvl);
  if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
    // Loose compression keeps both lo/hi positions, hence the doubling.
    if (isLooseCompressedLT(lt)) {
      Value two = constantIndex(builder, loc, 2);
      parentPos = arith::MulIOp::create(builder, loc, parentPos, two);
    }
    parentPos =
        genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
  } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
    // Singleton level: append the coordinate directly.
    createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
                   coords[lvl]);
  } else {
    assert(isDenseLT(lt));
    // New position: positions[lvl] = size * positions[lvl-1] + coords[lvl].
    Value size = desc.getLvlSize(builder, loc, lvl);
    Value mult = arith::MulIOp::create(builder, loc, size, parentPos);
    parentPos = arith::AddIOp::create(builder, loc, mult, coords[lvl]);
  }
}
if (!stt.isDenseLvl(lvlRank - 1))
  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                 std::nullopt, value);
else
  genStore(builder, loc, value, desc.getValMemRef(), parentPos);
return fields;
std::string getMangledFuncName() {
  constexpr const char kInsertFuncNamePrefix[] = "_insert_";
  const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
  SmallString<32> nameBuffer;
  llvm::raw_svector_ostream nameOstream(nameBuffer);
  nameOstream << kInsertFuncNamePrefix;
  const Level lvlRank = stt.getLvlRank();
for (Level l = 0; l < lvlRank; l++) {
  std::string lvlType = toMLIRString(stt.getLvlType(l));
  // Replace/remove punctuation in the level-type string.
  std::replace_if(
      lvlType.begin(), lvlType.end(),
      [](char c) { return c == '(' || c == ','; }, '_');
  llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
  nameOstream << lvlType << "_";
}
for (const auto sz : stt.getDimShape())
  nameOstream << sz << "_";
if (!stt.isIdentity())
  nameOstream << stt.getDimToLvl() << "_";
nameOstream << stt.getElementType() << "_";
nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
return nameOstream.str().str();
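The mangling scheme is easy to reproduce outside the compiler. The following standalone sketch (a hypothetical helper, not the MLIR implementation; the ordering component for non-identity maps is omitted) builds the same _insert_<lvlTypes>_<shape>_<eltType>_<crdWidth>_<posWidth> shape of name:

#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

std::string mangleInsertFunc(std::vector<std::string> lvlTypes,
                             const std::vector<int64_t> &shape,
                             const std::string &elemType, unsigned crdWidth,
                             unsigned posWidth) {
  std::string name = "_insert_";
  for (std::string &lt : lvlTypes) {
    // Normalize punctuation exactly as above: '(' and ',' become '_',
    // ')' and ' ' are dropped.
    std::replace_if(
        lt.begin(), lt.end(), [](char c) { return c == '(' || c == ','; },
        '_');
    lt.erase(std::remove_if(lt.begin(), lt.end(),
                            [](char c) { return c == ')' || c == ' '; }),
             lt.end());
    name += lt + "_";
  }
  for (int64_t sz : shape)
    name += std::to_string(sz) + "_";
  name += elemType + "_";
  name += std::to_string(crdWidth) + "_" + std::to_string(posWidth);
  return name;
}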
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<func::ReturnOp>(
        op, flattenValues(adaptor.getOperands()));
class SparseCallConverter : public OpConversionPattern<func::CallOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    SmallVector<Type> finalRetTy;
    if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
      return failure();
    auto newCall =
        func::CallOp::create(rewriter, loc, op.getCallee(), finalRetTy,
                             flattenValues(adaptor.getOperands()));
SmallVector<SmallVector<Value>> packedResultVals;
unsigned retOffset = 0;
SmallVector<Type> sparseFlat;
for (auto ret : op.getResults()) {
  assert(retOffset < newCall.getNumResults());
  auto retType = ret.getType();
  if (failed(typeConverter->convertType(retType, sparseFlat)))
    llvm_unreachable("Failed to convert type in sparse tensor codegen");
  assert(!sparseFlat.empty());
  if (sparseFlat.size() > 1) {
    auto flatSize = sparseFlat.size();
    packedResultVals.emplace_back();
    llvm::append_range(packedResultVals.back(),
                       newCall.getResults().slice(retOffset, flatSize));
    retOffset += flatSize;
  } else {
    // A 1:1 conversion: no need for casts.
    packedResultVals.emplace_back();
    packedResultVals.back().push_back(newCall.getResult(retOffset));
    retOffset++;
  }
  sparseFlat.clear();
assert(packedResultVals.size() == op.getNumResults());
rewriter.replaceOpWithMultiple(op, std::move(packedResultVals));
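The repacking step has a simple shape once the MLIR plumbing is removed. Sketch (hypothetical helper), where arity[i] is the number of flattened values that original result i expanded to:

#include <cstddef>
#include <vector>

std::vector<std::vector<int>>
repackResults(const std::vector<int> &flat, const std::vector<size_t> &arity) {
  std::vector<std::vector<int>> packed;
  size_t off = 0;
  for (size_t n : arity) {
    // Each original result maps to a contiguous slice of the new results.
    packed.emplace_back(flat.begin() + off, flat.begin() + off + n);
    off += n;
  }
  return packed;
}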
class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    std::optional<int64_t> lvl = op.getConstantLvlIndex();
    RankedTensorType srcType = op.getSource().getType();
    auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
    auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
    rewriter.replaceOp(op, sz);
struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
                                             op.getInputCoo().getType());
    SortOp::create(rewriter, loc, nnz, crd, ValueRange{val}, id,
                   rewriter.getIndexAttr(0), op.getAlgorithm());
    rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()});
template <typename Op, StorageSpecifierKind kind>
class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
  using OpConversionPattern<Op>::OpConversionPattern;
  using typename OpConversionPattern<Op>::OneToNOpAdaptor;
  LogicalResult
  matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply lowers to a specifier.get <field> operation.
    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
                                             op.getSlice().getType());
    auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
                                    op.getDim().getZExtValue());
    rewriter.replaceOp(op, v);
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/destination.
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply fold the operation.
    rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
  using OpConversionPattern::OpConversionPattern;
  SparseTensorAllocConverter(const TypeConverter &typeConverter,
                             MLIRContext *context, bool enableInit)
      : OpConversionPattern(typeConverter, context),
        enableBufferInitialization(enableInit) {}
LogicalResult
matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  if (!resType.hasEncoding())
    return failure();
  Location loc = op.getLoc();
  // Deal with copy.
  auto desc = getDescriptorFromTensorTuple(
      adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
SmallVector<Value> fields;
// Memcpy on memref fields.
for (auto field : desc.getMemRefFields()) {
  auto memrefTp = cast<MemRefType>(field.getType());
  auto size = memref::DimOp::create(rewriter, loc, field, 0);
  auto copied =
      memref::AllocOp::create(rewriter, loc, memrefTp, ValueRange{size});
  memref::CopyOp::create(rewriter, loc, field, copied);
  fields.push_back(copied);
}
rewriter.replaceOpWithMultiple(op, {fields});
if (!resType.isIdentity()) {
  return rewriter.notifyMatchFailure(
      op, "try run --sparse-reinterpret-map before codegen");
}
SmallVector<Value> lvlSizesValues;
Value sizeHint = op.getSizeHint();
SmallVector<Value> fields;
createAllocFields(rewriter, loc, resType, enableBufferInitialization,
                  sizeHint, lvlSizesValues, fields);
rewriter.replaceOpWithMultiple(op, {fields});

bool enableBufferInitialization;
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
  using OpConversionPattern::OpConversionPattern;
  SparseTensorEmptyConverter(const TypeConverter &typeConverter,
                             MLIRContext *context, bool enableInit)
      : OpConversionPattern(typeConverter, context),
        enableBufferInitialization(enableInit) {}
  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!resType.hasEncoding())
      return failure();
    if (!resType.isIdentity()) {
      return rewriter.notifyMatchFailure(
          op, "try run --sparse-reinterpret-map before codegen");
    }
    Location loc = op.getLoc();
    Value sizeHint; // none
    SmallVector<Value> lvlSizesValues;
    SmallVector<Value> fields;
    createAllocFields(rewriter, loc, resType, enableBufferInitialization,
                      sizeHint, lvlSizesValues, fields);
    rewriter.replaceOpWithMultiple(op, {fields});

  bool enableBufferInitialization;
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
  using OpConversionPattern::OpConversionPattern;
  SparseTensorDeallocConverter(const TypeConverter &typeConverter,
                               MLIRContext *context, bool createDeallocs)
      : OpConversionPattern(typeConverter, context),
        createDeallocs(createDeallocs) {}
  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (createDeallocs) {
      // Replace the sparse tensor deallocation with field deallocations.
      Location loc = op.getLoc();
      auto desc = getDescriptorFromTensorTuple(
          adaptor.getTensor(),
          cast<RankedTensorType>(op.getTensor().getType()));
      for (auto input : desc.getMemRefFields())
        memref::DeallocOp::create(rewriter, loc, input);
    }
    rewriter.eraseOp(op);

  const bool createDeallocs;
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Prepare descriptor.
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    // Generate optional insertion finalization code.
    if (op.getHasInserts())
      genEndInsert(rewriter, op.getLoc(), desc);
    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    const auto srcType = getSparseTensorType(op.getTensor());
    Type eltType = srcType.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
    // The size for access expansion (the innermost stored level size).
    const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
    // Allocate temporary buffers for values/filled-switch and added.
    const auto genAlloc = [&](Type t) {
      const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
      return memref::AllocOp::create(rewriter, loc, memTp, ValueRange{sz});
    };
    Value values = genAlloc(eltType);
    Value filled = genAlloc(boolType);
    Value added = genAlloc(idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Initialize the buffers.
    linalg::FillOp::create(rewriter, loc,
                           constantZero(rewriter, loc, eltType), values);
    linalg::FillOp::create(rewriter, loc,
                           constantZero(rewriter, loc, boolType), filled);
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, added, zero});
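The four buffers handed back by the expand lowering correspond to the classic expanded-access pattern. In plain C++ terms (a conceptual analogue with hypothetical names):

#include <cstdint>
#include <vector>

struct ExpandedRow {
  std::vector<double> values;  // dense scratch values, zero-initialized
  std::vector<bool> filled;    // which entries are currently live
  std::vector<uint64_t> added; // coordinates touched, in insertion order
  uint64_t count = 0;          // number of live entries

  explicit ExpandedRow(uint64_t sz) : values(sz, 0.0), filled(sz, false) {
    added.reserve(sz);
  }
  void addAt(uint64_t i, double v) {
    if (!filled[i]) { // first touch: remember the coordinate
      filled[i] = true;
      added.push_back(i);
      ++count;
    }
    values[i] += v;
  }
};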
class SparseCompressConverter : public OpConversionPattern<CompressOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    SmallVector<Value> fields;
    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
                                                op.getTensor().getType());
    Value values = llvm::getSingleElement(adaptor.getValues());
    Value filled = llvm::getSingleElement(adaptor.getFilled());
    Value added = llvm::getSingleElement(adaptor.getAdded());
    Value count = llvm::getSingleElement(adaptor.getCount());
    Type eltType = dstType.getElementType();
    // If the innermost level is ordered, sort the added coordinates first.
    if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
      SortOp::create(rewriter, loc, count, added, ValueRange{},
                     rewriter.getMultiDimIdentityMap(1),
                     rewriter.getIndexAttr(0),
                     SparseTensorSortKind::HybridQuickSort);
Value i = loop.getInductionVar();
Value crd = genLoad(rewriter, loc, added, i);
Value value = genLoad(rewriter, loc, values, crd);
SmallVector<Type> flatSpTensorTps = llvm::to_vector(
    llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
SmallVector<Value> flatLvlCoords = flattenValues(adaptor.getLvlCoords());
params.append(flatLvlCoords.begin(), flatLvlCoords.end());
params.push_back(crd);
params.push_back(value);
SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
                                params, /*genCall=*/true);
SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
scf::YieldOp::create(rewriter, loc, insertRet);
rewriter.setInsertionPointAfter(loop);
Operation *parent = getTop(op);
rewriter.setInsertionPointAfter(parent);
memref::DeallocOp::create(rewriter, loc, values);
memref::DeallocOp::create(rewriter, loc, filled);
memref::DeallocOp::create(rewriter, loc, added);
rewriter.replaceOpWithMultiple(op, {loop->getResults()});
class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!stt.hasEncoding())
      return failure();
    assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
    Location loc = op.getLoc();
    SmallVector<Value> params = llvm::to_vector(desc.getFields());
    SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
    params.append(flatIndices.begin(), flatIndices.end());
    params.push_back(llvm::getSingleElement(adaptor.getScalar()));
    SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
                                    params, /*genCall=*/true);
    SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
    rewriter.replaceOpWithMultiple(op, {ret});
class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
  using OpAdaptor = ToPositionsOp::Adaptor;
  using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested positions access with the corresponding field.
    Location loc = op.getLoc();
    Level lvl = op.getLevel();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    Value mem = desc.getPosMemRef(lvl);
    Value size = desc.getPosMemSize(rewriter, loc, lvl);
    rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
class SparseToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
  using OpAdaptor = ToCoordinatesOp::Adaptor;
  using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested coordinates access with the corresponding field.
    Location loc = op.getLoc();
    Level lvl = op.getLevel();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
    rewriter.replaceOp(op, mem);
class SparseToCoordinatesBufferConverter
    : public OpConversionPattern<ToCoordinatesBufferOp> {
  using OpAdaptor = ToCoordinatesBufferOp::Adaptor;
  using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    Value mem = desc.getAOSMemRef();
    Value size = desc.getCrdMemSize(rewriter, loc, 0);
    rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
  using OpAdaptor = ToValuesOp::Adaptor;
  using OpConversionPattern<ToValuesOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    Value mem = desc.getValMemRef();
    Value size = desc.getValMemSize(rewriter, loc);
    rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
    SparseTensorEncodingAttr encSrc =
        getSparseTensorEncoding(op.getSource().getType());
    assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slice.");
    if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
        encSrc.isSlice())
      return failure();
    Type retElemTp = op.getResult().getType().getElementType();
    Type srcElemTp = op.getSource().getType().getElementType();
    // Fold the trivial cases.
    if (retElemTp == srcElemTp && encDst == encSrc) {
      rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
      return success();
    }
Location loc = op.getLoc();
auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
                                            op.getSource().getType());
SmallVector<Value> fields;
foreachFieldAndTypeInSparseTensor(
    SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
    [&rewriter, &fields, srcDesc,
     loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind,
          Level /*lvl*/, LevelType /*lt*/) -> bool {
      // Simply reuse the storage specifier, as it is an SSA value.
      if (fKind == SparseTensorFieldKind::StorageSpec) {
        fields.push_back(srcDesc.getSpecifier());
      } else {
        Value srcMem = srcDesc.getMemRefField(fIdx);
        Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
        auto dstMem = memref::AllocOp::create(rewriter, loc,
                                              cast<MemRefType>(fTp), sz);
        if (fTp != srcMem.getType()) {
          // Converts the element type.
          scf::buildLoopNest(
              rewriter, loc, constantIndex(rewriter, loc, 0), sz,
              constantIndex(rewriter, loc, 1),
              [srcMem, &dstMem](OpBuilder &builder, Location loc,
                                ValueRange ivs) {
                Value v = memref::LoadOp::create(builder, loc, srcMem, ivs);
                Value casted = genCast(builder, loc, v,
                                       dstMem.getType().getElementType());
                memref::StoreOp::create(builder, loc, casted, dstMem, ivs);
              });
        } else {
          memref::CopyOp::create(rewriter, loc, srcMem, dstMem);
        }
        fields.push_back(dstMem);
      }
      return true;
    });
rewriter.replaceOpWithMultiple(op, {fields});
class SparseExtractSliceConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();
    auto srcEnc = getSparseTensorEncoding(op.getSourceType());
    auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
    if (!srcEnc || !dstEnc || !dstEnc.isSlice())
      return failure();
    assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
    SmallVector<Value> fields;
    auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
                                                op.getSource().getType());
    auto newSpec = StorageSpecifierInitOp::create(
        rewriter, loc, StorageSpecifierType::get(ctx, dstEnc),
        desc.getSpecifier());
    desc.setSpecifier(newSpec);
    for (auto [idx, offset, size, stride] : llvm::enumerate(
             op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
      // Sets the DimOffset/DimStride specifier fields for each dimension.
    }
    assert(srcEnc.isIdentity());
    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Query memSizes for the actually stored values.
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op.getResult());
    SmallVector<Value> fields;
    foreachFieldAndTypeInSparseTensor(
        stt,
        [&rewriter, &fields, &op, &stt,
         loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
              Level /*lvl*/, LevelType lt) -> bool {
          assert(fields.size() == fIdx);
          if (fKind == SparseTensorFieldKind::StorageSpec) {
            fields.push_back(
                SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
          } else {
            // Else simply takes the inputs.
            Value tensor = fKind == SparseTensorFieldKind::ValMemRef
                               ? op.getValues()
                               : op.getLevels()[fIdx];
            TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
            if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
              // Flattens the buffer to batchLvlRank.
              auto reassoc = getReassociationForFlattening(
                  mem.getType(), stt.getBatchLvlRank());
              mem = memref::CastOp::create(
                  rewriter, loc, fType,
                  memref::CollapseShapeOp::create(rewriter, loc, mem, reassoc));
            } else {
              mem = memref::CastOp::create(rewriter, loc, fType, mem);
            }
            fields.push_back(mem);
          }
          return true;
        });
    MutSparseTensorDescriptor desc(stt, fields);
// Sets up the storage specifier.
Value c0 = constantIndex(rewriter, loc, 0);
Value c1 = constantIndex(rewriter, loc, 1);
Value c2 = constantIndex(rewriter, loc, 2);
Value posBack = c0; // index to the last value in the position array
Value memSize = c1; // memory size for current array
Level trailCOOStart = stt.getAoSCOOStart();
Level trailCOORank = stt.getLvlRank() - trailCOOStart;
// Sets up the level sizes.
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
  assert(ShapedType::isStatic(stt.getDimShape()[lvl]));
  auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
  desc.setLvlSize(rewriter, loc, lvl, lvlSize);
  // Sets up the memory sizes, accounting for the trailing COO region.
  if (lvl > trailCOOStart)
    continue;
  LevelType lt = stt.getLvlType(lvl);
  if (lt.isa<LevelFormat::Dense>()) {
    memSize = arith::MulIOp::create(rewriter, loc, lvlSize, memSize);
    posBack = arith::SubIOp::create(rewriter, loc, memSize, c1);
    continue;
  }
  if (lt.isa<LevelFormat::Batch>()) {
    // Skips batch levels, which are not linearized.
    continue;
  }
  if (isLooseCompressedLT(lt)) {
    // Loose compression stores lo/hi position pairs.
    memSize = arith::MulIOp::create(rewriter, loc, memSize, c2);
    posBack = arith::SubIOp::create(rewriter, loc, memSize, c1);
  } else if (isCompressedLT(lt)) {
    // One more position entry than parents.
    memSize = arith::AddIOp::create(rewriter, loc, memSize, c1);
    posBack = arith::SubIOp::create(rewriter, loc, memSize, c1);
  }
  // Reads the last position value to obtain the next memory size.
  SmallVector<Value> batched(stt.getBatchLvlRank(),
                             constantIndex(rewriter, loc, 0));
  batched.push_back(posBack);
  memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
  posBack = arith::SubIOp::create(rewriter, loc, posBack, c1);
  if (lvl == trailCOOStart) {
    Value cooSz = arith::MulIOp::create(
        rewriter, loc, memSize, constantIndex(rewriter, loc, trailCOORank));
rewriter.replaceOpWithMultiple(op, {desc.getFields()});
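For the common CSR case the size bookkeeping above reduces to simple arithmetic. A sketch under the stated assumption of a 2-D dense-then-compressed tensor (hypothetical helper, not the MLIR implementation):

#include <cstdint>
#include <vector>

struct Sizes {
  uint64_t posMemSize, crdMemSize, valMemSize;
};

// Assumes `positions` has rows + 1 entries, CSR style.
Sizes csrSizes(const std::vector<uint64_t> &positions, uint64_t rows) {
  uint64_t memSize = rows;           // after the dense level: memSize *= size
  uint64_t posMemSize = memSize + 1; // compressed level: one extra position
  uint64_t posBack = posMemSize - 1; // index of the last position entry
  uint64_t nse = positions[posBack]; // last position = number of entries
  return {posMemSize, /*crdMemSize=*/nse, /*valMemSize=*/nse};
}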
struct SparseDisassembleOpConverter
    : public OpConversionPattern<DisassembleOp> {
  using OpConversionPattern::OpConversionPattern;
  SparseDisassembleOpConverter(const TypeConverter &typeConverter,
                               MLIRContext *context)
      : OpConversionPattern(typeConverter, context) {}
  LogicalResult
  matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    Location loc = op.getLoc();
    SmallVector<Value> retMem;
    SmallVector<Value> retLen;
    desc.getLayout().foreachField([&](FieldIndex fid,
                                      SparseTensorFieldKind fKind,
                                      Level lvl, LevelType lt) -> bool {
      if (fKind == SparseTensorFieldKind::StorageSpec)
        return true;
      SparseTensorType stt(desc.getRankedTensorType());
      Value sz, src;
      TypedValue<BaseMemRefType> dst;
      if (fKind == SparseTensorFieldKind::ValMemRef) {
        sz = desc.getValMemSize(rewriter, loc);
        src = desc.getValMemRef();
        dst = genToMemref(rewriter, loc, op.getOutValues());
        retMem.push_back(dst);
        Type valLenTp = op.getValLen().getType();
        retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
      } else {
        assert(fKind == SparseTensorFieldKind::PosMemRef ||
               fKind == SparseTensorFieldKind::CrdMemRef);
        sz = fKind == SparseTensorFieldKind::PosMemRef
                 ? desc.getPosMemSize(rewriter, loc, lvl)
                 : desc.getCrdMemSize(rewriter, loc, lvl);
        src = desc.getMemRefField(fid);
        dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
        retMem.push_back(dst);
        Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
      }
      Value flatOut = dst;
      if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
        auto reassoc = getReassociationForFlattening(dst.getType(),
                                                     stt.getBatchLvlRank());
        flatOut = memref::CollapseShapeOp::create(rewriter, loc, dst, reassoc);
      }
      Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
      Value srcMem = genSliceToSize(rewriter, loc, src, sz);
      memref::CopyOp::create(rewriter, loc, srcMem, dstMem);
      return true;
    });
    // Converts the memrefs back to tensors.
    SmallVector<Value> retValues = llvm::to_vector(
        llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
          return bufferization::ToTensorOp::create(
              rewriter, loc, memref::getTensorTypeFromMemRefType(v.getType()),
              v);
        }));
    retValues.append(retLen.begin(), retLen.end());
    rewriter.replaceOp(op, retValues);
struct SparseNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    // Creating COO with NewOp is handled by direct IR codegen; all other
    // cases are handled by rewriting.
    if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
      return failure();
    // Create a sparse tensor reader and obtain the dimension sizes.
    SmallVector<Value> dimSizesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
                             dimSizesValues, dimSizesBuffer);
const Type indexTp = rewriter.getIndexType();
Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
                           {indexTp}, {reader}, EmitCInterface::Off)
                .getResult(0);
// Construct allocation for each field.
SmallVector<Value> lvlSizesValues;
Value dim2lvlBuffer;
Value lvl2dimBuffer;
genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
              lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
Value sizeHint = nse;
SmallVector<Value> fields;
createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint,
                  lvlSizesValues, fields);
// Read the COO tensor data.
MutSparseTensorDescriptor desc(dstTp, fields);
Value xs = desc.getAOSMemRef();
Value ys = desc.getValMemRef();
const Type boolTp = rewriter.getIntegerType(1);
const Type elemTp = dstTp.getElementType();
const Type crdTp = dstTp.getCrdType();
SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
                                      overheadTypeFunctionSuffix(crdTp),
                                      primaryTypeFunctionSuffix(elemTp)};
Value isSorted =
    createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
                   {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
                   EmitCInterface::On)
        .getResult(0);
// If the destination is a sorted COO, sort the data unless the input
// elements are already sorted.
const Level lvlRank = dstTp.getLvlRank();
if (dstTp.isOrderedLvl(lvlRank - 1)) {
  Value kFalse = constantI1(rewriter, loc, false);
  Value notSorted = arith::CmpIOp::create(
      rewriter, loc, arith::CmpIPredicate::eq, isSorted, kFalse);
  scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, notSorted, false);
  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
  auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
  SortOp::create(rewriter, loc, nse, xs, ValueRange{ys}, xPerm,
                 rewriter.getIndexAttr(0),
                 SparseTensorSortKind::HybridQuickSort);
  rewriter.setInsertionPointAfter(ifOp);
}
// Sets the second position entry to nse.
const Value c1 = constantIndex(rewriter, loc, 1);
const Value posMemref0 = desc.getPosMemRef(0);
const Type posTp = dstTp.getPosType();
const Value posNse = genCast(rewriter, loc, nse, posTp);
memref::StoreOp::create(rewriter, loc, posNse, posMemref0, c1);
// Updates the storage specifier.
Value coordinatesSize = arith::MulIOp::create(
    rewriter, loc, nse, constantIndex(rewriter, loc, lvlRank));
// Releases the sparse tensor reader.
createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
               EmitCInterface::Off);
rewriter.replaceOpWithMultiple(op, {fields});
struct SparseHasRuntimeLibraryConverter
    : public OpConversionPattern<HasRuntimeLibraryOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto i1Type = rewriter.getI1Type();
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
        op, i1Type, rewriter.getIntegerAttr(i1Type, 0));
    return success();
  }
};
void mlir::populateSparseTensorCodegenPatterns(
    const TypeConverter &typeConverter, RewritePatternSet &patterns,
    bool createSparseDeallocs, bool enableBufferInitialization) {
  patterns.add<
      SparseAssembleOpConverter, SparseDisassembleOpConverter,
      SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
      SparseCastConverter, SparseExtractSliceConverter,
      SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter,
      SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter,
      SparseSliceGetterOpConverter<ToSliceOffsetOp,
                                   StorageSpecifierKind::DimOffset>,
      SparseSliceGetterOpConverter<ToSliceStrideOp,
                                   StorageSpecifierKind::DimStride>,
      SparseToPositionsConverter, SparseToCoordinatesConverter,
      SparseToCoordinatesBufferConverter, SparseToValuesConverter,
      SparseConvertConverter, SparseNewConverter,
      SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
      typeConverter, patterns.getContext());
  patterns.add<SparseTensorDeallocConverter>(
      typeConverter, patterns.getContext(), createSparseDeallocs);
  patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
      typeConverter, patterns.getContext(), enableBufferInitialization);
}
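For context, a pass would typically wire these rules up as follows. This is only a sketch: it assumes the usual MLIR conversion-pass boilerplate (headers, a configured ConversionTarget, and a sparse tensor codegen TypeConverter named `converter`); only the pattern-population step shown here comes from this file.

void collectSparseCodegenPatterns(const TypeConverter &converter,
                                  MLIRContext *ctx) {
  RewritePatternSet patterns(ctx);
  populateSparseTensorCodegenPatterns(converter, patterns,
                                      /*createSparseDeallocs=*/true,
                                      /*enableBufferInitialization=*/false);
  // A driver would then hand std::move(patterns) to applyPartialConversion.
}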
static void createAllocFields(OpBuilder &builder, Location loc, SparseTensorType stt, bool enableInit, Value sizeHint, SmallVectorImpl< Value > &lvlSizesValues, SmallVectorImpl< Value > &fields)
Creates allocation for each field in sparse tensor type.
static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper, MutableArrayRef< Value > fields, Value lower=Value())
Creates a straightforward counting for-loop.
static void genEndInsert(OpBuilder &builder, Location loc, SparseTensorDescriptor desc)
Generates insertion finalization code.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static void allocSchemeForRank(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, Level startLvl)
Generates code that allocates a sparse storage scheme for given rank.
static SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten the given value ranges into a single vector of values.
static Value genCompressed(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, ValueRange lvlCoords, Value, Value parentPos, Level lvl)
Helper method that generates block specific to compressed case:
static Value createAllocation(OpBuilder &builder, Location loc, MemRefType memRefType, Value sz, bool enableInit)
Creates allocation operation.
static void createPushback(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, SparseTensorFieldKind kind, std::optional< Level > lvl, Value value, Value repeat=Value())
Creates a push back operation.
static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, Value sz)
Generates a subview into the sizes.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
static SmallVector< ReassociationIndices > getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls)
Creates the reassociation array.
static void createDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt, ValueRange dynSizes, SmallVectorImpl< Value > &dimSizesValues)
Creates the dim sizes array, filling in from dynamic sizes.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext * getContext() const
Return the context this location is uniqued in.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Location getLoc()
The source location the operation was defined or derived from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A helper class to simplify lowering operations with/without function calls.
Using SmallVector for the mutable descriptor allows users to reuse it as a tmp buffer to append value fo...
void setMemRefField(SparseTensorFieldKind kind, std::optional< Level > lvl, Value v)
Additional setters for the mutable descriptor; each updates the value of the corresponding field.
void setSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl, Value v)
void setSpecifier(Value newSpec)
void setPosMemSize(OpBuilder &builder, Location loc, Level lvl, Value v)
void setValMemSize(OpBuilder &builder, Location loc, Value v)
void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value v)
void setField(FieldIndex fidx, Value v)
void setCrdMemSize(OpBuilder &builder, Location loc, Level lvl, Value v)
Value getSpecifier() const
Getters: get the value for required field.
RankedTensorType getRankedTensorType() const
ValueArrayRef getFields() const
Value getValMemSize(OpBuilder &builder, Location loc) const
Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl) const
std::pair< FieldIndex, unsigned > getCrdMemRefIndexAndStride(Level lvl) const
StorageLayout getLayout() const
ValueRange getMemRefFields() const
Value getAOSMemRef() const
Type getMemRefElementType(SparseTensorFieldKind kind, std::optional< Level > lvl) const
unsigned getNumFields() const
Value getCrdMemSize(OpBuilder &builder, Location loc, Level lvl) const
Value getMemRefField(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Value getPosMemSize(OpBuilder &builder, Location loc, Level lvl) const
Value getValMemRef() const
Value getField(FieldIndex fidx) const
Value getLvlSize(OpBuilder &builder, Location loc, Level lvl) const
Value getPosMemRef(Level lvl) const
Uses ValueRange for immutable descriptors.
static Value getInitValue(OpBuilder &builder, Location loc, SparseTensorType stt)
A wrapper around RankedTensorType, which has three goals:
Type getElementType() const
bool isAllOrdered() const
Returns true for tensors where every level is ordered.
bool isCOOType(Level startLvl=0, bool isUnique=true) const
Returns true iff this sparse tensor type has a trailing COO region starting at the given level.
Dimension getDimRank() const
Returns the dimension-rank.
bool isAllDense() const
Returns true for tensors where every level is dense.
bool hasSameDimToLvl(const SparseTensorType &other) const
Returns true iff the two types have the same mapping.
ArrayRef< Size > getDimShape() const
Returns the dimension-shape.
bool isCompressedLvl(Level l) const
Level getLvlRank() const
Returns the level-rank.
bool isDenseLvl(Level l) const
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
LevelType getLvlType(Level l) const
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
bool isUniqueLvl(Level l) const
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with t...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
SparseTensorDescriptor getDescriptorFromTensorTuple(ValueRange adaptorValues, RankedTensorType type)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool isWithCrdLT(LevelType lt)
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
bool isWithPosLT(LevelType lt)
std::string toMLIRString(LevelType lt)
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
bool isSingletonLT(LevelType lt)
bool isCompressedLT(LevelType lt)
TypedValue< BaseMemRefType > genToMemref(OpBuilder &builder, Location loc, Value tensor)
bool isLooseCompressedLT(LevelType lt)
unsigned FieldIndex
The type of field indices.
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, ValueRange s)
Generates a pointer/index load from the sparse storage scheme.
StringRef overheadTypeFunctionSuffix(OverheadType ot)
Convert OverheadType to its function-name suffix.
uint64_t Level
The type of level identifiers and level-ranks.
Operation * getTop(Operation *op)
Scans to top of generated loop.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Value genMapBuffers(OpBuilder &builder, Location loc, SparseTensorType stt, ArrayRef< Value > dimSizesValues, Value dimSizesBuffer, SmallVectorImpl< Value > &lvlSizesValues, Value &dim2lvlBuffer, Value &lvl2dimBuffer)
Generates code to set up the buffer parameters for a map.
Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt, Value tensor, SmallVectorImpl< Value > &dimSizesValues, Value &dimSizesBuffer)
Generates code that opens a reader and sets the dimension sizes.
Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem, Type dstTp)
Add conversion from scalar to given type (possibly a 0-rank tensor).
bool isDenseLT(LevelType lt)
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SparseTensorFieldKind
The sparse tensor storag...
func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Creates a CallOp to the function reference returned by getFunc() in the builder's module.
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy)
Add type casting between arith and index types when needed.
StringRef primaryTypeFunctionSuffix(PrimaryType pt)
Convert PrimaryType to its function-name suffix.
MutSparseTensorDescriptor getMutDescriptorFromTensorTuple(ValueRange adaptorValues, SmallVectorImpl< Value > &fields, RankedTensorType type)
StorageSpecifierKind toSpecifierKind(SparseTensorFieldKind kind)
bool isNOutOfMLT(LevelType lt)
Include the generated interface declarations.
void populateSparseTensorCodegenPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns, bool createSparseDeallocs, bool enableBufferInitialization)
Sets up sparse tensor codegen rules.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.