32#include "llvm/ADT/SmallVectorExtras.h"
46 for (
const auto &vals : values)
47 llvm::append_range(
result, vals);
54 return memref::LoadOp::create(builder, loc, mem, idx);
61 val =
genCast(builder, loc, val,
62 cast<ShapedType>(mem.
getType()).getElementType());
63 memref::StoreOp::create(builder, loc, val, mem, idx);
75 scf::ForOp::create(builder, loc, lower, upper, one, fields);
76 for (
unsigned i = 0, e = fields.size(); i < e; i++)
77 fields[i] = forOp.getRegionIterArg(i);
91 auto pushBackOp = PushBackOp::create(
93 field,
genCast(builder, loc, value, etp), repeat);
97 pushBackOp.getNewSize());
106 for (
Level lvl = startLvl; lvl < lvlRank; lvl++) {
117 linear = arith::MulIOp::create(builder, loc, linear, two);
130 linear = arith::MulIOp::create(builder, loc, linear, size);
135 std::nullopt, valZero, linear);
140 MemRefType memRefType,
Value sz,
142 Value buffer = memref::AllocOp::create(builder, loc, memRefType, sz);
143 Type elemType = memRefType.getElementType();
146 linalg::FillOp::create(builder, loc, fillValue, buffer);
156 dimSizesValues.clear();
157 dimSizesValues.reserve(dimRank);
160 dimSizesValues.push_back(ShapedType::isDynamic(sz)
179 Value posHeuristic, crdHeuristic, valHeuristic;
181 valHeuristic = lvlSizesValues[0];
182 for (
Level lvl = 1; lvl < lvlRank; lvl++)
183 valHeuristic = arith::MulIOp::create(builder, loc, valHeuristic,
184 lvlSizesValues[lvl]);
185 }
else if (sizeHint) {
188 crdHeuristic = arith::MulIOp::create(
189 builder, loc,
constantIndex(builder, loc, lvlRank), sizeHint);
191 posHeuristic = arith::AddIOp::create(builder, loc, sizeHint,
193 crdHeuristic = sizeHint;
195 posHeuristic = crdHeuristic =
constantIndex(builder, loc, 16);
197 valHeuristic = sizeHint;
199 posHeuristic = crdHeuristic = valHeuristic =
206 [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
209 assert(fields.size() == fIdx);
217 posHeuristic, enableInit);
221 crdHeuristic, enableInit);
225 valHeuristic, enableInit);
229 fields.push_back(field);
238 for (
Level lvl = 0, lvlRank = stt.
getLvlRank(); lvl < lvlRank; lvl++) {
239 desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
275 assert(lvl < lvlRank &&
"Level is out of bounds");
276 assert(lvlCoords.size() ==
static_cast<size_t>(lvlRank) &&
277 "Level-rank mismatch");
285 const Value pp1 = arith::AddIOp::create(builder, loc, parentPos, one);
287 const Value pstart =
genLoad(builder, loc, positionsAtLvl, parentPos);
288 const Value pstop =
genLoad(builder, loc, positionsAtLvl, pp1);
290 const Value crdStrideC =
293 crdStrideC ? arith::DivUIOp::create(builder, loc, crdMsz, crdStrideC)
295 const Value plast = arith::SubIOp::create(
296 builder, loc,
genCast(builder, loc, pstop, indexType), one);
298 Value lt = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult,
300 types.push_back(boolType);
301 scf::IfOp ifOp1 = scf::IfOp::create(builder, loc, types, lt,
true);
306 crdStrideC ? arith::MulIOp::create(builder, loc, plast, crdStrideC)
308 Value eq = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
309 genCast(builder, loc, crd, indexType),
311 scf::YieldOp::create(builder, loc, eq);
314 genStore(builder, loc, msz, positionsAtLvl, parentPos);
315 scf::YieldOp::create(builder, loc,
constantI1(builder, loc,
false));
322 for (
unsigned i = 0, e = desc.
getNumFields(); i < e; i++)
324 types.push_back(indexType);
327 scf::IfOp ifOp2 = scf::IfOp::create(builder, loc, types, p,
true);
334 scf::YieldOp::create(builder, loc, desc.
getFields());
339 Value mszp1 = arith::AddIOp::create(builder, loc, msz, one);
340 genStore(builder, loc, mszp1, positionsAtLvl, pp1);
344 if ((lvl + 1) < lvlRank)
348 scf::YieldOp::create(builder, loc, desc.
getFields());
354 for (
unsigned i = 0, e = desc.
getNumFields(); i < e; i++)
355 desc.
setField(i, ifOp2.getResult(o++));
356 return ifOp2.getResult(o);
364 for (
Level lvl = 0; lvl < lvlRank; lvl++) {
381 scf::ForOp loop =
createFor(builder, loc, hi, inits, one);
382 Value i = loop.getInductionVar();
383 Value oldv = loop.getRegionIterArg(0);
386 Value cond = arith::CmpIOp::create(
387 builder, loc, arith::CmpIPredicate::eq, newv, posZero);
388 scf::IfOp ifOp = scf::IfOp::create(builder, loc,
TypeRange(posType),
391 genStore(builder, loc, oldv, posMemRef, i);
392 scf::YieldOp::create(builder, loc, oldv);
394 scf::YieldOp::create(builder, loc, newv);
396 scf::YieldOp::create(builder, loc, ifOp.getResult(0));
409 auto memTp = llvm::cast<MemRefType>(mem.
getType());
413 if (memTp.getRank() > 1)
416 return memref::SubViewOp::create(
418 MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
432 for (
unsigned i = 0; i < batchLvls; i++)
435 for (
int i = batchLvls, e = srcTp.getRank(); i < e; i++)
436 ret.back().push_back(i);
448class SparseInsertGenerator
453 : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp) {};
466 OpBuilder &builder, Location loc) {
467 const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
468 const Level lvlRank = stt.getLvlRank();
470 SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
471 MutSparseTensorDescriptor desc(stt, fields);
472 const SmallVector<Value> coords =
473 llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
474 Value value = args.back();
477 for (
Level lvl = 0; lvl < lvlRank; lvl++) {
478 const auto lt = stt.getLvlType(lvl);
489 parentPos = arith::MulIOp::create(builder, loc, parentPos, two);
492 genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
498 createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
505 Value size = desc.
getLvlSize(builder, loc, lvl);
506 Value mult = arith::MulIOp::create(builder, loc, size, parentPos);
507 parentPos = arith::AddIOp::create(builder, loc, mult, coords[lvl]);
511 if (!stt.isDenseLvl(lvlRank - 1))
512 createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
513 std::nullopt, value);
519 std::string getMangledFuncName() {
522 constexpr const char kInsertFuncNamePrefix[] =
"_insert_";
523 const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
524 SmallString<32> nameBuffer;
525 llvm::raw_svector_ostream nameOstream(nameBuffer);
526 nameOstream << kInsertFuncNamePrefix;
527 const Level lvlRank = stt.getLvlRank();
528 for (
Level l = 0; l < lvlRank; l++) {
532 lvlType.begin(), lvlType.end(),
533 [](
char c) { return c ==
'(' || c ==
','; },
'_');
534 llvm::erase_if(lvlType, [](
char c) {
return c ==
')' || c ==
' '; });
535 nameOstream << lvlType <<
"_";
540 for (
const auto sz : stt.getDimShape())
541 nameOstream << sz <<
"_";
543 if (!stt.isIdentity())
544 nameOstream << stt.getDimToLvl() <<
"_";
545 nameOstream << stt.getElementType() <<
"_";
546 nameOstream << stt.getCrdWidth() <<
"_" << stt.getPosWidth();
547 return nameOstream.str().str();
555class SparseReturnConverter :
public OpConversionPattern<func::ReturnOp> {
557 using OpConversionPattern::OpConversionPattern;
559 matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
560 ConversionPatternRewriter &rewriter)
const override {
562 rewriter.replaceOpWithNewOp<func::ReturnOp>(
569class SparseCallConverter :
public OpConversionPattern<func::CallOp> {
572 using OpConversionPattern::OpConversionPattern;
574 matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
575 ConversionPatternRewriter &rewriter)
const override {
576 Location loc = op.getLoc();
582 SmallVector<Type> finalRetTy;
583 if (
failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
588 func::CallOp::create(rewriter, loc, op.getCallee(), finalRetTy,
591 SmallVector<SmallVector<Value>> packedResultVals;
594 unsigned retOffset = 0;
597 SmallVector<Type> sparseFlat;
598 for (
auto ret : op.getResults()) {
599 assert(retOffset < newCall.getNumResults());
600 auto retType = ret.getType();
601 if (
failed(typeConverter->convertType(retType, sparseFlat)))
602 llvm_unreachable(
"Failed to convert type in sparse tensor codegen");
605 assert(!sparseFlat.empty());
606 if (sparseFlat.size() > 1) {
607 auto flatSize = sparseFlat.size();
608 packedResultVals.emplace_back();
609 llvm::append_range(packedResultVals.back(),
610 newCall.getResults().slice(retOffset, flatSize));
611 retOffset += flatSize;
614 packedResultVals.emplace_back();
615 packedResultVals.back().push_back(newCall.getResult(retOffset));
621 assert(packedResultVals.size() == op.getNumResults());
622 rewriter.replaceOpWithMultiple(op, std::move(packedResultVals));
628class SparseLvlOpConverter :
public OpConversionPattern<LvlOp> {
630 using OpConversionPattern::OpConversionPattern;
632 matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
633 ConversionPatternRewriter &rewriter)
const override {
634 std::optional<int64_t> lvl = op.getConstantLvlIndex();
635 RankedTensorType srcType = op.getSource().getType();
640 auto sz = desc.
getLvlSize(rewriter, op.getLoc(), *lvl);
642 rewriter.replaceOp(op, sz);
648struct SparseReorderCOOConverter :
public OpConversionPattern<ReorderCOOOp> {
649 using OpConversionPattern::OpConversionPattern;
651 matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
652 ConversionPatternRewriter &rewriter)
const override {
653 Location loc = op.getLoc();
666 op.getInputCoo().getType());
677 SortOp::create(rewriter, loc, nnz, crd,
ValueRange{val}, id,
678 rewriter.getIndexAttr(0), op.getAlgorithm());
682 rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()});
687template <
typename Op, StorageSpecifierKind kind>
688class SparseSliceGetterOpConverter :
public OpConversionPattern<Op> {
690 using OpConversionPattern<
Op>::OpConversionPattern;
691 using typename OpConversionPattern<Op>::OneToNOpAdaptor;
694 matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
695 ConversionPatternRewriter &rewriter)
const override {
698 op.getSlice().getType());
700 op.getDim().getZExtValue());
702 rewriter.replaceOp(op, v);
708class SparseCastConverter :
public OpConversionPattern<tensor::CastOp> {
710 using OpConversionPattern::OpConversionPattern;
712 matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
713 ConversionPatternRewriter &rewriter)
const override {
717 if (!encDst || encDst != encSrc)
719 rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
724class SparseReMapConverter :
public OpConversionPattern<ReinterpretMapOp> {
726 using OpConversionPattern::OpConversionPattern;
728 matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
729 ConversionPatternRewriter &rewriter)
const override {
731 rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
737class SparseTensorAllocConverter
738 :
public OpConversionPattern<bufferization::AllocTensorOp> {
740 using OpConversionPattern::OpConversionPattern;
741 SparseTensorAllocConverter(
const TypeConverter &typeConverter,
742 MLIRContext *context,
bool enableInit)
743 : OpConversionPattern(typeConverter, context),
744 enableBufferInitialization(enableInit) {}
747 matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
748 ConversionPatternRewriter &rewriter)
const override {
750 if (!resType.hasEncoding())
753 Location loc = op.getLoc();
757 adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
758 SmallVector<Value> fields;
762 auto memrefTp = cast<MemRefType>(field.getType());
763 auto size = memref::DimOp::create(rewriter, loc, field, 0);
765 memref::AllocOp::create(rewriter, loc, memrefTp,
ValueRange{size});
766 memref::CopyOp::create(rewriter, loc, field, copied);
767 fields.push_back(copied);
772 rewriter.replaceOpWithMultiple(op, {fields});
776 if (!resType.isIdentity()) {
777 return rewriter.notifyMatchFailure(
778 op,
"try run --sparse-reinterpret-map before codegen");
781 SmallVector<Value> lvlSizesValues;
787 Value sizeHint = op.getSizeHint();
788 SmallVector<Value> fields;
790 sizeHint, lvlSizesValues, fields);
793 rewriter.replaceOpWithMultiple(op, {fields});
798 bool enableBufferInitialization;
802class SparseTensorEmptyConverter :
public OpConversionPattern<tensor::EmptyOp> {
804 using OpConversionPattern::OpConversionPattern;
805 SparseTensorEmptyConverter(
const TypeConverter &typeConverter,
806 MLIRContext *context,
bool enableInit)
807 : OpConversionPattern(typeConverter, context),
808 enableBufferInitialization(enableInit) {}
811 matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
812 ConversionPatternRewriter &rewriter)
const override {
814 if (!resType.hasEncoding())
817 if (!resType.isIdentity()) {
818 return rewriter.notifyMatchFailure(
819 op,
"try run --sparse-reinterpret-map before codegen");
822 Location loc = op.getLoc();
824 SmallVector<Value> lvlSizesValues;
829 SmallVector<Value> fields;
831 sizeHint, lvlSizesValues, fields);
834 rewriter.replaceOpWithMultiple(op, {fields});
839 bool enableBufferInitialization;
843class SparseTensorDeallocConverter
844 :
public OpConversionPattern<bufferization::DeallocTensorOp> {
846 using OpConversionPattern::OpConversionPattern;
847 SparseTensorDeallocConverter(
const TypeConverter &typeConverter,
848 MLIRContext *context,
bool createDeallocs)
849 : OpConversionPattern(typeConverter, context),
850 createDeallocs(createDeallocs) {}
853 matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
854 ConversionPatternRewriter &rewriter)
const override {
861 if (createDeallocs) {
863 Location loc = op.getLoc();
866 cast<RankedTensorType>(op.getTensor().getType()));
869 memref::DeallocOp::create(rewriter, loc, input);
871 rewriter.eraseOp(op);
876 const bool createDeallocs;
880class SparseTensorLoadConverter :
public OpConversionPattern<LoadOp> {
882 using OpConversionPattern::OpConversionPattern;
884 matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
885 ConversionPatternRewriter &rewriter)
const override {
888 op.getTensor().getType());
890 if (op.getHasInserts())
893 rewriter.replaceOpWithMultiple(op, {desc.
getFields()});
899class SparseExpandConverter :
public OpConversionPattern<ExpandOp> {
901 using OpConversionPattern::OpConversionPattern;
903 matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
904 ConversionPatternRewriter &rewriter)
const override {
907 Location loc = op->getLoc();
909 op.getTensor().getType());
911 Type eltType = srcType.getElementType();
912 Type boolType = rewriter.getIntegerType(1);
913 Type idxType = rewriter.getIndexType();
915 rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
919 const auto sz = desc.
getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
921 const auto genAlloc = [&](Type t) {
922 const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
923 return memref::AllocOp::create(rewriter, loc, memTp,
ValueRange{sz});
928 Value values = genAlloc(eltType);
929 Value filled = genAlloc(boolType);
930 Value added = genAlloc(idxType);
937 linalg::FillOp::create(rewriter, loc,
940 linalg::FillOp::create(rewriter, loc,
944 assert(op.getNumResults() == 4);
945 rewriter.replaceOp(op, {values, filled, added, zero});
951class SparseCompressConverter :
public OpConversionPattern<CompressOp> {
953 using OpConversionPattern::OpConversionPattern;
955 matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
956 ConversionPatternRewriter &rewriter)
const override {
957 Location loc = op->getLoc();
958 SmallVector<Value> fields;
960 op.getTensor().getType());
961 Value values = llvm::getSingleElement(adaptor.getValues());
962 Value filled = llvm::getSingleElement(adaptor.getFilled());
963 Value added = llvm::getSingleElement(adaptor.getAdded());
964 Value count = llvm::getSingleElement(adaptor.getCount());
966 Type eltType = dstType.getElementType();
970 if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
971 SortOp::create(rewriter, loc, count, added,
ValueRange{},
972 rewriter.getMultiDimIdentityMap(1),
973 rewriter.getIndexAttr(0),
974 SparseTensorSortKind::HybridQuickSort);
991 Value i = loop.getInductionVar();
993 Value crd =
genLoad(rewriter, loc, added, i);
994 Value value =
genLoad(rewriter, loc, values, crd);
996 SmallVector<Type> flatSpTensorTps = llvm::map_to_vector(
997 desc.
getFields(), [](Value v) { return v.getType(); });
998 SmallVector<Value> flatLvlCoords =
flattenValues(adaptor.getLvlCoords());
999 params.append(flatLvlCoords.begin(), flatLvlCoords.end());
1000 params.push_back(crd);
1001 params.push_back(value);
1002 SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
1004 SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
1007 scf::YieldOp::create(rewriter, loc, insertRet);
1009 rewriter.setInsertionPointAfter(loop);
1011 Operation *parent =
getTop(op);
1012 rewriter.setInsertionPointAfter(parent);
1013 memref::DeallocOp::create(rewriter, loc, values);
1014 memref::DeallocOp::create(rewriter, loc, filled);
1015 memref::DeallocOp::create(rewriter, loc, added);
1017 rewriter.replaceOpWithMultiple(op, {loop->getResults()});
1023class SparseInsertConverter :
public OpConversionPattern<tensor::InsertOp> {
1025 using OpConversionPattern::OpConversionPattern;
1027 matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
1028 ConversionPatternRewriter &rewriter)
const override {
1030 if (!stt.hasEncoding())
1032 assert(stt.isIdentity() &&
"Run reinterpret-map before conversion.");
1034 Location loc = op.getLoc();
1038 SmallVector<Value> params = llvm::to_vector(desc.
getFields());
1039 SmallVector<Value> flatIndices =
flattenValues(adaptor.getIndices());
1040 params.append(flatIndices.begin(), flatIndices.end());
1041 params.push_back(llvm::getSingleElement(adaptor.getScalar()));
1042 SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
1044 SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
1046 rewriter.replaceOpWithMultiple(op, {ret});
1052class SparseToPositionsConverter :
public OpConversionPattern<ToPositionsOp> {
1054 using OpAdaptor = ToPositionsOp::Adaptor;
1055 using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
1057 matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
1058 ConversionPatternRewriter &rewriter)
const override {
1062 Location loc = op.getLoc();
1063 Level lvl = op.getLevel();
1065 op.getTensor().getType());
1068 rewriter.replaceOp(op,
genSliceToSize(rewriter, loc, mem, size));
1074class SparseToCoordinatesConverter
1075 :
public OpConversionPattern<ToCoordinatesOp> {
1077 using OpAdaptor = ToCoordinatesOp::Adaptor;
1078 using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
1080 matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
1081 ConversionPatternRewriter &rewriter)
const override {
1085 Location loc = op.getLoc();
1086 Level lvl = op.getLevel();
1088 op.getTensor().getType());
1089 auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
1094 rewriter.replaceOp(op, mem);
1100class SparseToCoordinatesBufferConverter
1101 :
public OpConversionPattern<ToCoordinatesBufferOp> {
1103 using OpAdaptor = ToCoordinatesBufferOp::Adaptor;
1104 using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
1106 matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
1107 ConversionPatternRewriter &rewriter)
const override {
1111 Location loc = op.getLoc();
1114 op.getTensor().getType());
1117 rewriter.replaceOp(op,
genSliceToSize(rewriter, loc, mem, size));
1123class SparseToValuesConverter :
public OpConversionPattern<ToValuesOp> {
1125 using OpAdaptor = ToValuesOp::Adaptor;
1126 using OpConversionPattern<ToValuesOp>::OpConversionPattern;
1128 matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
1129 ConversionPatternRewriter &rewriter)
const override {
1133 Location loc = op.getLoc();
1135 op.getTensor().getType());
1138 rewriter.replaceOp(op,
genSliceToSize(rewriter, loc, mem, size));
1144class SparseConvertConverter :
public OpConversionPattern<ConvertOp> {
1146 using OpConversionPattern::OpConversionPattern;
1148 matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
1149 ConversionPatternRewriter &rewriter)
const override {
1151 SparseTensorEncodingAttr encSrc =
1155 assert(!encDst.isSlice() &&
"Cannot convert to a sparse tensor slices.");
1159 if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
1164 Type retElemTp = op.getResult().getType().getElementType();
1165 Type srcElemTp = op.getSource().getType().getElementType();
1167 if (retElemTp == srcElemTp && encDst == encSrc) {
1168 rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
1180 Location loc = op.getLoc();
1182 op.getSource().getType());
1183 SmallVector<Value> fields;
1185 SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
1186 [&rewriter, &fields, srcDesc,
1188 LevelType ) ->
bool {
1190 if (fKind == SparseTensorFieldKind::StorageSpec) {
1191 fields.push_back(srcDesc.getSpecifier());
1194 Value srcMem = srcDesc.getMemRefField(fIdx);
1198 Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
1199 auto dstMem = memref::AllocOp::create(rewriter, loc,
1200 cast<MemRefType>(fTp), sz);
1201 if (fTp != srcMem.getType()) {
1204 rewriter, loc, constantIndex(rewriter, loc, 0), sz,
1205 constantIndex(rewriter, loc, 1),
1206 [srcMem, &dstMem](OpBuilder &builder, Location loc,
1208 Value v = memref::LoadOp::create(builder, loc, srcMem, ivs);
1209 Value casted = genCast(builder, loc, v,
1210 dstMem.getType().getElementType());
1211 memref::StoreOp::create(builder, loc, casted, dstMem, ivs);
1217 memref::CopyOp::create(rewriter, loc, srcMem, dstMem);
1219 fields.push_back(dstMem);
1224 rewriter.replaceOpWithMultiple(op, {fields});
1229class SparseExtractSliceConverter
1230 :
public OpConversionPattern<tensor::ExtractSliceOp> {
1232 using OpConversionPattern::OpConversionPattern;
1234 matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
1235 ConversionPatternRewriter &rewriter)
const override {
1236 Location loc = op.getLoc();
1237 MLIRContext *ctx = op.getContext();
1241 if (!srcEnc || !dstEnc || !dstEnc.isSlice())
1243 assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
1245 SmallVector<Value> fields;
1247 op.getSource().getType());
1249 auto newSpec = StorageSpecifierInitOp::create(
1250 rewriter, loc, StorageSpecifierType::get(ctx, dstEnc),
1255 for (
auto [idx, offset, size, stride] : llvm::enumerate(
1256 op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
1271 assert(srcEnc.isIdentity());
1281 rewriter.replaceOpWithMultiple(op, {desc.
getFields()});
1287class SparseNumberOfEntriesConverter
1288 :
public OpConversionPattern<NumberOfEntriesOp> {
1290 using OpConversionPattern::OpConversionPattern;
1292 matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
1293 ConversionPatternRewriter &rewriter)
const override {
1298 op.getTensor().getType());
1299 rewriter.replaceOp(op, desc.
getValMemSize(rewriter, op.getLoc()));
1304struct SparseAssembleOpConverter :
public OpConversionPattern<AssembleOp> {
1305 using OpConversionPattern::OpConversionPattern;
1307 matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
1308 ConversionPatternRewriter &rewriter)
const override {
1309 Location loc = op.getLoc();
1312 SmallVector<Value> fields;
1316 [&rewriter, &fields, &op, &stt,
1318 Level , LevelType lt) ->
bool {
1319 assert(fields.size() == fIdx);
1320 if (fKind == SparseTensorFieldKind::StorageSpec) {
1325 Value tensor = fKind == SparseTensorFieldKind::ValMemRef
1327 : op.getLevels()[fIdx];
1330 if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
1333 mem.getType(), stt.getBatchLvlRank());
1334 mem = memref::CastOp::create(
1335 rewriter, loc, fType,
1336 memref::CollapseShapeOp::create(rewriter, loc, mem, reassoc));
1338 mem = memref::CastOp::create(rewriter, loc, fType, mem);
1340 fields.push_back(mem);
1345 MutSparseTensorDescriptor desc(stt, fields);
1352 Level trailCOOStart = stt.getAoSCOOStart();
1353 Level trailCOORank = stt.getLvlRank() - trailCOOStart;
1355 for (
Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
1356 assert(ShapedType::isStatic(stt.getDimShape()[lvl]));
1359 auto lvlSize =
constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
1360 desc.
setLvlSize(rewriter, loc, lvl, lvlSize);
1363 if (lvl > trailCOOStart)
1367 LevelType lt = stt.getLvlType(lvl);
1369 if (lt.
isa<LevelFormat::Dense>()) {
1370 memSize = arith::MulIOp::create(rewriter, loc, lvlSize, memSize);
1371 posBack = arith::SubIOp::create(rewriter, loc, memSize, c1);
1374 if (lt.
isa<LevelFormat::Batch>()) {
1384 memSize = arith::MulIOp::create(rewriter, loc, memSize, c2);
1385 posBack = arith::SubIOp::create(rewriter, loc, memSize, c1);
1389 memSize = arith::AddIOp::create(rewriter, loc, memSize, c1);
1395 SmallVector<Value> batched(stt.getBatchLvlRank(),
1397 batched.push_back(posBack);
1399 posBack = arith::SubIOp::create(rewriter, loc, posBack, c1);
1403 if (lvl == trailCOOStart) {
1404 Value cooSz = arith::MulIOp::create(
1405 rewriter, loc, memSize,
constantIndex(rewriter, loc, trailCOORank));
1413 rewriter.replaceOpWithMultiple(op, {desc.
getFields()});
1418struct SparseDisassembleOpConverter
1419 :
public OpConversionPattern<DisassembleOp> {
1420 using OpConversionPattern::OpConversionPattern;
1421 SparseDisassembleOpConverter(
const TypeConverter &typeConverter,
1422 MLIRContext *context)
1423 : OpConversionPattern(typeConverter, context) {}
1426 matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
1427 ConversionPatternRewriter &rewriter)
const override {
1429 op.getTensor().getType());
1430 Location loc = op.getLoc();
1431 SmallVector<Value> retMem;
1432 SmallVector<Value> retLen;
1436 Level lvl, LevelType lt) ->
bool {
1437 if (fKind == SparseTensorFieldKind::StorageSpec)
1442 if (fKind == SparseTensorFieldKind::ValMemRef) {
1445 dst =
genToMemref(rewriter, loc, op.getOutValues());
1447 retMem.push_back(dst);
1448 Type valLenTp = op.getValLen().getType();
1451 assert(fKind == SparseTensorFieldKind::PosMemRef ||
1452 fKind == SparseTensorFieldKind::CrdMemRef);
1454 sz = fKind == SparseTensorFieldKind::PosMemRef
1458 dst =
genToMemref(rewriter, loc, op.getOutLevels()[fid]);
1459 retMem.push_back(dst);
1461 Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
1464 Value flatOut = dst;
1465 if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
1468 flatOut = memref::CollapseShapeOp::create(rewriter, loc, dst, reassoc);
1472 memref::CopyOp::create(rewriter, loc, srcMem, dstMem);
1477 SmallVector<Value> retValues =
1478 llvm::map_to_vector(retMem, [&rewriter, loc](Value v) -> Value {
1479 return bufferization::ToTensorOp::create(
1484 retValues.append(retLen.begin(), retLen.end());
1485 rewriter.replaceOp(op, retValues);
1490struct SparseNewConverter :
public OpConversionPattern<NewOp> {
1491 using OpConversionPattern::OpConversionPattern;
1493 matchAndRewrite(
NewOp op, OpAdaptor adaptor,
1494 ConversionPatternRewriter &rewriter)
const override {
1495 Location loc = op.getLoc();
1499 if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
1513 SmallVector<Value> dimSizesValues;
1514 Value dimSizesBuffer;
1515 Value reader =
genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
1516 dimSizesValues, dimSizesBuffer);
1519 const Type indexTp = rewriter.getIndexType();
1520 Value nse =
createFuncCall(rewriter, loc,
"getSparseTensorReaderNSE",
1521 {indexTp}, {reader}, EmitCInterface::Off)
1525 SmallVector<Value> lvlSizesValues;
1526 Value dim2lvlBuffer;
1527 Value lvl2dimBuffer;
1528 genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
1529 lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
1532 Value sizeHint = nse;
1533 SmallVector<Value> fields;
1535 lvlSizesValues, fields);
1538 MutSparseTensorDescriptor desc(dstTp, fields);
1541 const Type boolTp = rewriter.getIntegerType(1);
1542 const Type elemTp = dstTp.getElementType();
1543 const Type crdTp = dstTp.getCrdType();
1544 SmallString<32> readToBuffersFuncName{
"getSparseTensorReaderReadToBuffers",
1549 {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
1555 const Level lvlRank = dstTp.getLvlRank();
1556 if (dstTp.isOrderedLvl(lvlRank - 1)) {
1557 Value kFalse =
constantI1(rewriter, loc,
false);
1558 Value notSorted = arith::CmpIOp::create(
1559 rewriter, loc, arith::CmpIPredicate::eq, isSorted, kFalse);
1561 scf::IfOp::create(rewriter, loc, notSorted,
false);
1562 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1563 auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
1564 SortOp::create(rewriter, loc, nse, xs,
ValueRange{ys}, xPerm,
1565 rewriter.getIndexAttr(0),
1566 SparseTensorSortKind::HybridQuickSort);
1567 rewriter.setInsertionPointAfter(ifOp);
1573 const Type posTp = dstTp.getPosType();
1574 const Value posNse =
genCast(rewriter, loc, nse, posTp);
1575 memref::StoreOp::create(rewriter, loc, posNse, posMemref0, c1);
1578 Value coordinatesSize = arith::MulIOp::create(
1586 createFuncCall(rewriter, loc,
"delSparseTensorReader", {}, {reader},
1587 EmitCInterface::Off);
1590 rewriter.replaceOpWithMultiple(op, {fields});
1595struct SparseHasRuntimeLibraryConverter
1596 :
public OpConversionPattern<HasRuntimeLibraryOp> {
1597 using OpConversionPattern::OpConversionPattern;
1599 matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
1600 ConversionPatternRewriter &rewriter)
const override {
1601 auto i1Type = rewriter.getI1Type();
1602 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
1603 op, i1Type, rewriter.getIntegerAttr(i1Type, 0));
1618 bool createSparseDeallocs,
bool enableBufferInitialization) {
1620 SparseAssembleOpConverter, SparseDisassembleOpConverter,
1621 SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
1622 SparseCastConverter, SparseExtractSliceConverter,
1623 SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter,
1624 SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter,
1625 SparseSliceGetterOpConverter<ToSliceOffsetOp,
1626 StorageSpecifierKind::DimOffset>,
1627 SparseSliceGetterOpConverter<ToSliceStrideOp,
1628 StorageSpecifierKind::DimStride>,
1629 SparseToPositionsConverter, SparseToCoordinatesConverter,
1630 SparseToCoordinatesBufferConverter, SparseToValuesConverter,
1631 SparseConvertConverter, SparseNewConverter,
1632 SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
1633 typeConverter,
patterns.getContext());
1634 patterns.add<SparseTensorDeallocConverter>(
1635 typeConverter,
patterns.getContext(), createSparseDeallocs);
1636 patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
1637 typeConverter,
patterns.getContext(), enableBufferInitialization);
memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values))
static void createAllocFields(OpBuilder &builder, Location loc, SparseTensorType stt, bool enableInit, Value sizeHint, SmallVectorImpl< Value > &lvlSizesValues, SmallVectorImpl< Value > &fields)
Creates allocation for each field in sparse tensor type.
static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper, MutableArrayRef< Value > fields, Value lower=Value())
Creates a straightforward counting for-loop.
static void genEndInsert(OpBuilder &builder, Location loc, SparseTensorDescriptor desc)
Generates insertion finalization code.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static void allocSchemeForRank(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, Level startLvl)
Generates code that allocates a sparse storage scheme for given rank.
static SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten the given value ranges into a single vector of values.
static Value genCompressed(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, ValueRange lvlCoords, Value, Value parentPos, Level lvl)
Helper method that generates block specific to compressed case:
static Value createAllocation(OpBuilder &builder, Location loc, MemRefType memRefType, Value sz, bool enableInit)
Creates allocation operation.
static void createPushback(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, SparseTensorFieldKind kind, std::optional< Level > lvl, Value value, Value repeat=Value())
Creates a push back operation.
static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, Value sz)
Generates a subview into the sizes.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
static SmallVector< ReassociationIndices > getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls)
Creates the reassociation array.
static void createDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt, ValueRange dynSizes, SmallVectorImpl< Value > &dimSizesValues)
Creates the dim sizes array, filling in from dynamic sizes.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext * getContext() const
Return the context this location is uniqued in.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Location getLoc()
The source location the operation was defined or derived from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A helper class to simplify lowering operations with/without function calls.
Using SmallVector for mutable descriptor allows users to reuse it as a tmp buffers to append value fo...
void setMemRefField(SparseTensorFieldKind kind, std::optional< Level > lvl, Value v)
Adds additional setters for mutable descriptor, update the value for required field.
void setSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl, Value v)
void setSpecifier(Value newSpec)
void setPosMemSize(OpBuilder &builder, Location loc, Level lvl, Value v)
void setValMemSize(OpBuilder &builder, Location loc, Value v)
void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value v)
void setField(FieldIndex fidx, Value v)
void setCrdMemSize(OpBuilder &builder, Location loc, Level lvl, Value v)
Value getSpecifier() const
Getters: get the value for required field.
RankedTensorType getRankedTensorType() const
ValueArrayRef getFields() const
Value getValMemSize(OpBuilder &builder, Location loc) const
Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl) const
std::pair< FieldIndex, unsigned > getCrdMemRefIndexAndStride(Level lvl) const
StorageLayout getLayout() const
ValueRange getMemRefFields() const
Value getAOSMemRef() const
Type getMemRefElementType(SparseTensorFieldKind kind, std::optional< Level > lvl) const
unsigned getNumFields() const
Value getCrdMemSize(OpBuilder &builder, Location loc, Level lvl) const
Value getMemRefField(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Value getPosMemSize(OpBuilder &builder, Location loc, Level lvl) const
Value getValMemRef() const
Value getField(FieldIndex fidx) const
Value getLvlSize(OpBuilder &builder, Location loc, Level lvl) const
Value getPosMemRef(Level lvl) const
Uses ValueRange for immutable descriptors.
static Value getInitValue(OpBuilder &builder, Location loc, SparseTensorType stt)
A wrapper around RankedTensorType, which has three goals:
Type getElementType() const
bool isAllOrdered() const
Returns true for tensors where every level is ordered.
bool isCOOType(Level startLvl=0, bool isUnique=true) const
Returns true iff this sparse tensor type has a trailing COO region starting at the given level.
Dimension getDimRank() const
Returns the dimension-rank.
bool isAllDense() const
Returns true for tensors where every level is dense.
bool hasSameDimToLvl(const SparseTensorType &other) const
Returns true iff the two types have the same mapping.
ArrayRef< Size > getDimShape() const
Returns the dimension-shape.
bool isCompressedLvl(Level l) const
Level getLvlRank() const
Returns the level-rank.
bool isDenseLvl(Level l) const
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
LevelType getLvlType(Level l) const
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
bool isUniqueLvl(Level l) const
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with t...
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
SparseTensorDescriptor getDescriptorFromTensorTuple(ValueRange adaptorValues, RankedTensorType type)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool isWithCrdLT(LevelType lt)
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
bool isWithPosLT(LevelType lt)
std::string toMLIRString(LevelType lt)
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
bool isSingletonLT(LevelType lt)
bool isCompressedLT(LevelType lt)
TypedValue< BaseMemRefType > genToMemref(OpBuilder &builder, Location loc, Value tensor)
bool isLooseCompressedLT(LevelType lt)
unsigned FieldIndex
The type of field indices.
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, ValueRange s)
Generates a pointer/index load from the sparse storage scheme.
StringRef overheadTypeFunctionSuffix(OverheadType ot)
Convert OverheadType to its function-name suffix.
uint64_t Level
The type of level identifiers and level-ranks.
Operation * getTop(Operation *op)
Scans to top of generated loop.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Value genMapBuffers(OpBuilder &builder, Location loc, SparseTensorType stt, ArrayRef< Value > dimSizesValues, Value dimSizesBuffer, SmallVectorImpl< Value > &lvlSizesValues, Value &dim2lvlBuffer, Value &lvl2dimBuffer)
Generates code to set up the buffer parameters for a map.
Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt, Value tensor, SmallVectorImpl< Value > &dimSizesValues, Value &dimSizesBuffer)
Generates code that opens a reader and sets the dimension sizes.
Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem, Type dstTp)
Add conversion from scalar to given type (possibly a 0-rank tensor).
bool isDenseLT(LevelType lt)
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SparseTensorFieldKind
===-------------------------------------------------------------------===// The sparse tensor storag...
func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Creates a CallOp to the function reference returned by getFunc() in the builder's module.
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy)
Add type casting between arith and index types when needed.
StringRef primaryTypeFunctionSuffix(PrimaryType pt)
Convert PrimaryType to its function-name suffix.
MutSparseTensorDescriptor getMutDescriptorFromTensorTuple(ValueRange adaptorValues, SmallVectorImpl< Value > &fields, RankedTensorType type)
StorageSpecifierKind toSpecifierKind(SparseTensorFieldKind kind)
bool isNOutOfMLT(LevelType lt)
Include the generated interface declarations.
void populateSparseTensorCodegenPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns, bool createSparseDeallocs, bool enableBufferInitialization)
Sets up sparse tensor codegen rules.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.