/// Flatten the given value ranges into a single vector of values.
static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
  SmallVector<Value> result;
  for (const auto &vals : values)
    llvm::append_range(result, vals);
  return result;
}

/// Generates a load with proper index typing.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
  idx = genCast(builder, loc, idx, builder.getIndexType());
  return memref::LoadOp::create(builder, loc, mem, idx);
}

/// Generates a store with proper index typing and proper value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
                     Value idx) {
  idx = genCast(builder, loc, idx, builder.getIndexType());
  val = genCast(builder, loc, val,
                cast<ShapedType>(mem.getType()).getElementType());
  memref::StoreOp::create(builder, loc, val, mem, idx);
}

/// Creates a straightforward counting for-loop.
static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
                            MutableArrayRef<Value> fields,
                            Value lower = Value()) {
  if (!lower)
    lower = constantZero(builder, loc, builder.getIndexType());
  Value one = constantOne(builder, loc, builder.getIndexType());
  scf::ForOp forOp =
      scf::ForOp::create(builder, loc, lower, upper, one, fields);
  for (unsigned i = 0, e = fields.size(); i < e; i++)
    fields[i] = forOp.getRegionIterArg(i);
  return forOp;
}
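// A minimal usage sketch (not from this file): thread all storage fields
// through a counting loop as iter_args. After the call, `fields` refers to
// the region iter_args, so code emitted inside the loop updates the
// loop-carried copies rather than the originals.
//
//   SmallVector<Value> fields(desc.getFields().begin(),
//                             desc.getFields().end());
//   scf::ForOp loop = createFor(builder, loc, count, fields);
//   // ... emit body that mutates `fields`, then yield them ...
//   scf::YieldOp::create(builder, loc, fields);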
// Inside createPushback(): append the value to the selected memref field and
// record the new size in the storage specifier.
auto pushBackOp = PushBackOp::create(
    builder, loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl),
    field, genCast(builder, loc, value, etp), repeat);
desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
desc.setSpecifierField(builder, loc, specFieldKind, lvl,
                       pushBackOp.getNewSize());
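// Call sketch (mirrors the use in allocSchemeForRank() below): append one
// zero position, repeated `linear` times, to the positions buffer of level
// `lvl`. The buffer may reallocate, so both the descriptor field and its
// stored size are refreshed above.
//
//   createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
//                  /*value=*/posZero, /*repeat=*/linear);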
// Inside allocSchemeForRank(): compound a linear repeat count while scanning
// levels from startLvl, then bootstrap the first non-dense buffer.
for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
  // ... in the loose-compressed branch, both lo/hi positions are appended:
  linear = arith::MulIOp::create(builder, loc, linear, two);
  // ... in the dense branch, the level size keeps compounding:
  linear = arith::MulIOp::create(builder, loc, linear, size);
}
// Reached the values array, so prepare for an insertion.
createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
               std::nullopt, valZero, linear);
/// Creates allocation operation.
static Value createAllocation(OpBuilder &builder, Location loc,
                              MemRefType memRefType, Value sz,
                              bool enableInit) {
  Value buffer = memref::AllocOp::create(builder, loc, memRefType, sz);
  Type elemType = memRefType.getElementType();
  if (enableInit) {
    Value fillValue = constantZero(builder, loc, elemType);
    linalg::FillOp::create(builder, loc, fillValue, buffer);
  }
  return buffer;
}
// Inside createDimSizes(): fill the dim sizes array, taking dynamic sizes
// from `dynSizes` and materializing static sizes as constants.
dimSizesValues.clear();
dimSizesValues.reserve(dimRank);
unsigned i = 0;
for (const Size sz : stt.getDimShape())
  dimSizesValues.push_back(ShapedType::isDynamic(sz)
                               ? dynSizes[i++]
                               : constantIndex(builder, loc, sz));
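// Worked example (hypothetical type): for tensor<?x8xf64, #CSR> with one
// dynamic size %d, this produces dimSizesValues = [%d, %c8], where %c8 is
// constantIndex(builder, loc, 8).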
// Inside createAllocFields(): heuristic initial capacities for the
// positions, coordinates, and values buffers.
Value posHeuristic, crdHeuristic, valHeuristic;
if (stt.isAllDense()) {
  valHeuristic = lvlSizesValues[0];
  for (Level lvl = 1; lvl < lvlRank; lvl++)
    valHeuristic = arith::MulIOp::create(builder, loc, valHeuristic,
                                         lvlSizesValues[lvl]);
} else if (sizeHint) {
  if (stt.getAoSCOOStart() == 0) {
    posHeuristic = constantIndex(builder, loc, 2);
    crdHeuristic = arith::MulIOp::create(
        builder, loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
  } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
    posHeuristic = arith::AddIOp::create(builder, loc, sizeHint,
                                         constantIndex(builder, loc, 1));
    crdHeuristic = sizeHint;
  } else {
    posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
  }
  valHeuristic = sizeHint;
} else {
  posHeuristic = crdHeuristic = valHeuristic =
      constantIndex(builder, loc, 16);
}
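// Worked example (assumed shapes): for a CSR matrix (dense, compressed) with
// sizeHint = nnz, the middle branch applies: positions get nnz + 1 slots,
// coordinates and values get nnz slots each. With no size hint at all, every
// buffer starts at a capacity of 16 and grows via push_back.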
foreachFieldAndTypeInSparseTensor(
    stt,
    [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
     enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
                 Level /*lvl*/, LevelType /*lt*/) -> bool {
      assert(fields.size() == fIdx);
      Value field;
      // ... StorageSpec gets an initial specifier value; PosMemRef,
      // CrdMemRef, and ValMemRef are allocated with createAllocation()
      // using posHeuristic, crdHeuristic, and valHeuristic, respectively:
      //   field = createAllocation(builder, loc, cast<MemRefType>(fType),
      //                            posHeuristic, enableInit);
      fields.push_back(field);
      return true;
    });
// Initialize the storage scheme to an empty tensor.
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
  desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
  // ...
}
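// The next excerpt is from genCompressed(), which emits the per-level probe
// and insert logic for a compressed level. A rough sketch of the generated
// code (paraphrased, not verbatim from the file):
//
//   pstart = positions[lvl][parentPos]
//   pstop  = positions[lvl][parentPos + 1]
//   plast  = pstop - 1
//   msz    = coordinates[lvl].size()
//   present = (pstart < pstop) && coordinates[lvl][plast] == lvlCoords[lvl]
//   if (present) {
//     pnext = plast
//   } else {
//     coordinates[lvl].push_back(lvlCoords[lvl])
//     positions[lvl][parentPos + 1] = msz + 1
//     pnext = msz
//     <prepare level lvl + 1>
//   }
//   return pnext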
assert(lvl < lvlRank && "Level is out of bounds");
assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
       "Level-rank mismatch");
// ...
const Value pp1 = arith::AddIOp::create(builder, loc, parentPos, one);
const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
const auto [crdFidx, crdStride] = desc.getCrdMemRefIndexAndStride(lvl);
const Value crdStrideC =
    crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
const Value msz =
    crdStrideC ? arith::DivUIOp::create(builder, loc, crdMsz, crdStrideC)
               : crdMsz;
const Value plast = arith::SubIOp::create(
    builder, loc, genCast(builder, loc, pstop, indexType), one);
// Conditional expression: test whether this level was already appended to.
Value lt = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ult,
                                 pstart, pstop);
types.push_back(boolType);
scf::IfOp ifOp1 = scf::IfOp::create(builder, loc, types, lt,
                                    /*else*/ true);
// Then-region: test whether the last stored coordinate matches.
Value crd = genLoad(
    builder, loc, desc.getMemRefField(crdFidx),
    crdStrideC ? arith::MulIOp::create(builder, loc, plast, crdStrideC)
               : plast);
Value eq = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq,
                                 genCast(builder, loc, crd, indexType),
                                 lvlCoords[lvl]);
scf::YieldOp::create(builder, loc, eq);
// Else-region: first insertion at this parent position.
genStore(builder, loc, msz, positionsAtLvl, parentPos);
scf::YieldOp::create(builder, loc, constantI1(builder, loc, false));
// ...
for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
  types.push_back(desc.getField(i).getType());
types.push_back(indexType);
// For a non-unique level, the condition is simply false.
const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
                                     : constantI1(builder, loc, false);
scf::IfOp ifOp2 = scf::IfOp::create(builder, loc, types, p,
                                    /*else*/ true);
// If present: fields remain unchanged, next position is plast.
scf::YieldOp::create(builder, loc, desc.getFields());
// If not present: append the coordinate and update the positions.
Value mszp1 = arith::AddIOp::create(builder, loc, msz, one);
genStore(builder, loc, mszp1, positionsAtLvl, pp1);
createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
               /*value=*/lvlCoords[lvl]);
// Prepare the next level "as needed".
if ((lvl + 1) < lvlRank)
  allocSchemeForRank(builder, loc, desc, lvl + 1);
scf::YieldOp::create(builder, loc, desc.getFields());
// Update the fields from the if-results and return the next position.
unsigned o = 0;
for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
  desc.setField(i, ifOp2.getResult(o++));
return ifOp2.getResult(o);
// Inside genEndInsert(): finalize the position buffers per level.
for (Level lvl = 0; lvl < lvlRank; lvl++) {
scf::ForOp loop = createFor(builder, loc, hi, inits, one);
Value i = loop.getInductionVar();
Value oldv = loop.getRegionIterArg(0);
Value newv = genLoad(builder, loc, posMemRef, i);
Value posZero = constantZero(builder, loc, posType);
// Copy forward the last stored position on a zero.
Value cond = arith::CmpIOp::create(
    builder, loc, arith::CmpIPredicate::eq, newv, posZero);
scf::IfOp ifOp = scf::IfOp::create(builder, loc, TypeRange(posType),
                                   cond, /*else*/ true);
// Then-region: overwrite the zero with the carried value.
genStore(builder, loc, oldv, posMemRef, i);
scf::YieldOp::create(builder, loc, oldv);
// Else-region: keep the loaded value.
scf::YieldOp::create(builder, loc, newv);
// After the if: carry the selected value forward.
scf::YieldOp::create(builder, loc, ifOp.getResult(0));
// Inside genSliceToSize(): generate a subview into the sizes.
auto memTp = llvm::cast<MemRefType>(mem.getType());
// Batched (higher-dimensional) buffers are returned as-is.
if (memTp.getRank() > 1)
  return mem;
// Truncate linear memrefs to the given size via a rank-1 dynamic subview
// at offset 0 with stride 1.
return memref::SubViewOp::create(
    // ...
    );
/// Creates the reassociation array.
static SmallVector<ReassociationIndices>
getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
  SmallVector<ReassociationIndices> ret;
  // Keep the batch levels as-is, then collapse all trailing dimensions
  // into one group.
  for (unsigned i = 0; i < batchLvls; i++)
    ret.push_back({i});
  ret.push_back({});
  for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
    ret.back().push_back(i);
  return ret;
}
class SparseInsertGenerator
    : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
// Translate the insertion coordinates level by level into a position in
// the storage scheme.
const Level lvlRank = stt.getLvlRank();
SmallVector<Value> coords =
    llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
Value value = args.back();
Value parentPos = constantZero(builder, loc, builder.getIndexType());
// Generate code for every level.
for (Level lvl = 0; lvl < lvlRank; lvl++) {
  const auto lt = stt.getLvlType(lvl);
  if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
    // Loose compression stores positions in lo/hi pairs.
    if (isLooseCompressedLT(lt))
      parentPos = arith::MulIOp::create(builder, loc, parentPos, two);
    parentPos =
        genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
  } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
    // ... push back coords[lvl] to this level's coordinates ...
  } else {
    assert(isDenseLT(lt));
    // positions[lvl] = size * positions[lvl-1] + coords[lvl]
    Value size = desc.getLvlSize(builder, loc, lvl);
    Value mult = arith::MulIOp::create(builder, loc, size, parentPos);
    parentPos = arith::AddIOp::create(builder, loc, mult, coords[lvl]);
  }
}
// Reached the actual value append/insert.
if (!stt.isDenseLvl(lvlRank - 1))
  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                 std::nullopt, value);
std::string getMangledFuncName() {
  // The mangled name of the function has this format:
  //   <prefix>_<lvlTypes>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
  constexpr const char kInsertFuncNamePrefix[] = "_insert_";
  SmallString<32> nameBuffer;
  llvm::raw_svector_ostream nameOstream(nameBuffer);
  nameOstream << kInsertFuncNamePrefix;
  const Level lvlRank = stt.getLvlRank();
  for (Level l = 0; l < lvlRank; l++) {
    std::string lvlType = toMLIRString(stt.getLvlType(l));
    // Replace/remove punctuations in level properties.
    std::replace_if(
        lvlType.begin(), lvlType.end(),
        [](char c) { return c == '(' || c == ','; }, '_');
    llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
    nameOstream << lvlType << "_";
  }
  // Static dim sizes are used in the generated code while dynamic sizes are
  // loaded from the dimSizes buffer, so the shape is part of the name.
  for (const auto sz : stt.getDimShape())
    nameOstream << sz << "_";
  // Permutation information is also used in generating insertion.
  if (!stt.isIdentity())
    nameOstream << stt.getDimToLvl() << "_";
  nameOstream << stt.getElementType() << "_";
  nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
  return nameOstream.str().str();
}
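// Illustrative example (name assumed, not taken from a test): a 10x10
// CSR-like tensor with f64 elements and default (index) overhead widths
// would mangle to something like
//
//   _insert_dense_compressed_10_10_f64_0_0
//
// so that structurally identical inserts can share one generated function.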
LogicalResult
matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
LogicalResult
matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  Location loc = op.getLoc();
  SmallVector<Type> finalRetTy;
  if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
    return failure();
  // (1) Generate the new call with flattened return types and operands.
  auto newCall =
      func::CallOp::create(rewriter, loc, op.getCallee(), finalRetTy,
                           flattenValues(adaptor.getOperands()));
  // (2) Gather the sparse tensor returns. `retOffset` tracks the offset of
  // the current return value (of the original call) relative to the new
  // call (after sparse tensor flattening).
  SmallVector<SmallVector<Value>> packedResultVals;
  unsigned retOffset = 0;
  // Temporary buffer to hold the flattened list of types for a sparse tensor.
  SmallVector<Type> sparseFlat;
  for (auto ret : op.getResults()) {
    assert(retOffset < newCall.getNumResults());
    auto retType = ret.getType();
    if (failed(typeConverter->convertType(retType, sparseFlat)))
      llvm_unreachable("Failed to convert type in sparse tensor codegen");
    // Converted types can not be empty when the type conversion succeeds.
    assert(!sparseFlat.empty());
    if (sparseFlat.size() > 1) {
      auto flatSize = sparseFlat.size();
      packedResultVals.emplace_back();
      llvm::append_range(packedResultVals.back(),
                         newCall.getResults().slice(retOffset, flatSize));
      retOffset += flatSize;
    } else {
      // A 1:1 conversion; no need for casting.
      packedResultVals.emplace_back();
      packedResultVals.back().push_back(newCall.getResult(retOffset));
      retOffset++;
    }
    sparseFlat.clear();
  }
  assert(packedResultVals.size() == op.getNumResults());
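// Illustrative transformation (schematic IR, not verbatim): a call returning
// a sparse tensor is rewritten to return that tensor's flattened fields.
//
//   %st = call @foo(...) : (...) -> tensor<?x?xf64, #CSR>
// becomes, after flattening,
//   %pos, %crd, %val, %spec = call @foo(...)
//       : (...) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>,
//                   !sparse_tensor.storage_specifier<#CSR>)
// and users of %st are repointed at the packed field group.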
LogicalResult
matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  std::optional<int64_t> lvl = op.getConstantLvlIndex();
  RankedTensorType srcType = op.getSource().getType();
  // ...
  auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
  auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
  rewriter.replaceOp(op, sz);
  return success();
}
LogicalResult
matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  auto srcDesc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
                                              op.getInputCoo().getType());
  // ...
  // Sort the nnz entries of the COO buffers in place via `id` ordering.
  SortOp::create(rewriter, loc, nnz, crd, ValueRange{val}, id,
                 op.getAlgorithm());
template <typename Op, StorageSpecifierKind kind>
class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
  LogicalResult
  matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply lowers to a specifier.get <field> operation.
    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
                                             op.getSlice().getType());
    auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
                                    op.getDim().getZExtValue());
    rewriter.replaceOp(op, v);
matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // Only rewrite identically annotated source/destination types.
  if (!encDst || encDst != encSrc)
    return failure();
matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  SparseTensorAllocConverter(const TypeConverter &typeConverter,
                             MLIRContext *context, bool enableInit)
      : OpConversionPattern(typeConverter, context),
        enableBufferInitialization(enableInit) {}
LogicalResult
matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  const auto resType = getSparseTensorType(op);
  if (!resType.hasEncoding())
    return failure();
  Location loc = op.getLoc();
  // Deal with copy.
  if (op.getCopy()) {
    auto desc = getDescriptorFromTensorTuple(
        adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
    SmallVector<Value> fields;
    fields.reserve(desc.getNumFields());
    // Memcpy on memref fields.
    for (auto field : desc.getMemRefFields()) {
      auto memrefTp = cast<MemRefType>(field.getType());
      auto size = memref::DimOp::create(rewriter, loc, field, 0);
      auto copied =
          memref::AllocOp::create(rewriter, loc, memrefTp, ValueRange{size});
      memref::CopyOp::create(rewriter, loc, field, copied);
      fields.push_back(copied);
    }
    // ... reuse the specifier as-is, then replace with the new fields ...
  }
  if (!resType.isIdentity()) {
    return rewriter.notifyMatchFailure(
        op, "try running --sparse-reinterpret-map before codegen");
  }
  // ...
  // Construct the allocation for each field.
  Value sizeHint = op.getSizeHint();
  SmallVector<Value> fields;
  createAllocFields(rewriter, loc, resType, enableBufferInitialization,
                    sizeHint, lvlSizesValues, fields);
bool enableBufferInitialization;
SparseTensorEmptyConverter(const TypeConverter &typeConverter,
                           MLIRContext *context, bool enableInit)
    : OpConversionPattern(typeConverter, context),
      enableBufferInitialization(enableInit) {}
LogicalResult
matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  const auto resType = getSparseTensorType(op);
  if (!resType.hasEncoding())
    return failure();
  if (!resType.isIdentity()) {
    return rewriter.notifyMatchFailure(
        op, "try running --sparse-reinterpret-map before codegen");
  }
  // ...
  createAllocFields(rewriter, loc, resType, enableBufferInitialization,
                    sizeHint, lvlSizesValues, fields);
bool enableBufferInitialization;
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  SparseTensorDeallocConverter(const TypeConverter &typeConverter,
                               MLIRContext *context, bool createDeallocs)
      : OpConversionPattern(typeConverter, context),
        createDeallocs(createDeallocs) {}
LogicalResult
matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  if (createDeallocs) {
    // Replace the tensor deallocation with deallocations of its buffers.
    auto desc = getDescriptorFromTensorTuple(
        adaptor.getTensor(),
        cast<RankedTensorType>(op.getTensor().getType()));
    for (auto input : desc.getMemRefFields())
      memref::DeallocOp::create(rewriter, loc, input);
  }
const bool createDeallocs;
matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // Prepare the descriptor.
  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                           op.getTensor().getType());
  // Generate optional insertion finalization code.
  if (op.getHasInserts())
    genEndInsert(rewriter, op.getLoc(), desc);
matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                           op.getTensor().getType());
  const auto srcType = getSparseTensorType(op.getTensor());
  Type eltType = srcType.getElementType();
  // ...
  // Determine the size for access expansion (always the innermost stored
  // level size).
  const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
  // Generate a memref for `sz` elements of type `t`.
  const auto genAlloc = [&](Type t) {
    const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
    return memref::AllocOp::create(rewriter, loc, memTp, ValueRange{sz});
  };
  // Allocate temporary buffers for values, filled-switch, and added set.
  Value values = genAlloc(eltType);
  Value filled = genAlloc(boolType);
  Value added = genAlloc(idxType);
  Value zero = constantZero(rewriter, loc, idxType);
  // Reset the values/filled-switch to all-zero/false.
  linalg::FillOp::create(rewriter, loc,
                         ValueRange{constantZero(rewriter, loc, eltType)},
                         ValueRange{values});
  linalg::FillOp::create(rewriter, loc,
                         ValueRange{constantZero(rewriter, loc, boolType)},
                         ValueRange{filled});
  // Replace the expansion op with these buffers and the initial coordinate.
  assert(op.getNumResults() == 4);
  rewriter.replaceOp(op, {values, filled, added, zero});
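// Note on the generated buffers (editorial summary): `values` holds the
// expanded values, `filled` marks which entries are live, `added` records
// the coordinates that were touched, and `zero` starts the running count.
// Schematically:
//
//   %values, %filled, %added, %count = sparse_tensor.expand %t
//       : tensor<...> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index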
matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
                                              op.getTensor().getType());
  Value values = llvm::getSingleElement(adaptor.getValues());
  Value filled = llvm::getSingleElement(adaptor.getFilled());
  Value added = llvm::getSingleElement(adaptor.getAdded());
  Value count = llvm::getSingleElement(adaptor.getCount());
  const SparseTensorType dstType(desc.getRankedTensorType());
  Type eltType = dstType.getElementType();
  // If the innermost level is ordered, sort the coordinates in the "added"
  // array prior to applying the compression.
  if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
    SortOp::create(rewriter, loc, count, added, ValueRange{},
                   rewriter.getMultiDimIdentityMap(1),
                   constantIndex(rewriter, loc, 0),
                   SparseTensorSortKind::HybridQuickSort);
  // ...
  scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
  Value i = loop.getInductionVar();
  // ...
  SmallVector<Type> flatSpTensorTps = llvm::to_vector(
      llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
  params.append(flatLvlCoords.begin(), flatLvlCoords.end());
  params.push_back(crd);
  params.push_back(value);
  SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
                                  params, /*genCall=*/true);
  // ...
  scf::YieldOp::create(rewriter, loc, insertRet);
  // Deallocate the buffers on exit of the full loop nest.
  memref::DeallocOp::create(rewriter, loc, values);
  memref::DeallocOp::create(rewriter, loc, filled);
  memref::DeallocOp::create(rewriter, loc, added);
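// The loop emitted above follows this shape (paraphrased pseudocode):
//
//   for (i = 0; i < count; i++) {
//     crd = added[i];
//     value = values[crd];
//     insert({lvlCoords, crd}, value);
//     values[crd] = 0;
//     filled[crd] = false;
//   }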
matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  auto stt = getSparseTensorType(op.getDest());
  if (!stt.hasEncoding())
    return failure();
  assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
  // ...
  params.append(flatIndices.begin(), flatIndices.end());
  params.push_back(llvm::getSingleElement(adaptor.getScalar()));
  SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
                                  params, /*genCall=*/true);
using OpAdaptor = typename ToPositionsOp::Adaptor;
LogicalResult
matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // Replace the requested position access with the corresponding field.
  // The view is restricted to the actual size so clients of this operation
  // truly observe the size, not the capacity.
  Level lvl = op.getLevel();
  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                           op.getTensor().getType());
class SparseToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpAdaptor = typename ToCoordinatesOp::Adaptor;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested coordinates access with the corresponding field.
    Level lvl = op.getLevel();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
class SparseToCoordinatesBufferConverter
    : public OpConversionPattern<ToCoordinatesBufferOp> {
public:
  using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
  LogicalResult
  matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested AoS coordinates access with the AoS buffer.
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
using OpAdaptor = typename ToValuesOp::Adaptor;
LogicalResult
matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // Replace the requested values access with the corresponding field.
  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                           op.getTensor().getType());
matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
  SparseTensorEncodingAttr encSrc =
      getSparseTensorEncoding(op.getSource().getType());
  // The output tensor can not be a slice.
  assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slice.");
  // Only rewrite when the encodings agree up to overhead bit widths.
  if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
      /* ... */)
    return failure();
  Type retElemTp = op.getResult().getType().getElementType();
  Type srcElemTp = op.getSource().getType().getElementType();
  // Fold the trivial cases.
  if (retElemTp == srcElemTp && encDst == encSrc) {
    rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
    return success();
  }
  // ...
  auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
                                              op.getSource().getType());
  SmallVector<Value> fields;
  foreachFieldAndTypeInSparseTensor(
      SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
      [&rewriter, &fields, srcDesc,
       loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
            LevelType /*lt*/) -> bool {
        // Simply reuse the storage specifier, as it is an SSA value.
        if (fKind == SparseTensorFieldKind::StorageSpec) {
          fields.push_back(srcDesc.getSpecifier());
        } else {
          // Allocate a new buffer and copy over, casting when needed.
          Value srcMem = srcDesc.getMemRefField(fIdx);
          // The buffers can grow, so we can not simply rely on the type;
          // use the actual stored size instead.
          Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
          auto dstMem = memref::AllocOp::create(rewriter, loc,
                                                cast<MemRefType>(fTp), sz);
          if (fTp != srcMem.getType()) {
            // Converts the element type via an elementwise copy loop.
            scf::buildLoopNest(
                rewriter, loc, constantIndex(rewriter, loc, 0), sz,
                constantIndex(rewriter, loc, 1),
                [srcMem, &dstMem](OpBuilder &builder, Location loc,
                                  ValueRange ivs) {
                  Value v = memref::LoadOp::create(builder, loc, srcMem, ivs);
                  Value casted = genCast(builder, loc, v,
                                         dstMem.getType().getElementType());
                  memref::StoreOp::create(builder, loc, casted, dstMem, ivs);
                });
          } else {
            memref::CopyOp::create(rewriter, loc, srcMem, dstMem);
          }
          fields.push_back(dstMem);
        }
        return true;
      });
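// Editorial summary: this direct codegen path only handles conversions
// where source and destination share the same storage layout, so every
// buffer can be copied (or element-cast) field by field. For example, a
// same-encoding convert folds away entirely:
//
//   %1 = sparse_tensor.convert %0
//       : tensor<?x?xf64, #CSR> to tensor<?x?xf64, #CSR>
//
// Layout-changing conversions are rejected above and handled elsewhere.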
class SparseExtractSliceConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (!srcEnc || !dstEnc || !dstEnc.isSlice())
      return failure();
    assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
    auto desc = getDescriptorFromTensorTuple(adaptor.getSource(),
                                             op.getSource().getType());
    auto newSpec = StorageSpecifierInitOp::create(
        rewriter, loc, specTy, desc.getSpecifier());
    // ...
    for (auto [idx, offset, size, stride] : llvm::enumerate(
             op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
      // ... record slice offset/size/stride in the new specifier ...
    }
    // ...
    assert(srcEnc.isIdentity());
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Query memSizes for the actually stored values.
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  foreachFieldAndTypeInSparseTensor(
      stt,
      [&rewriter, &fields, &op, &stt,
       loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
            Level /*lvl*/, LevelType lt) -> bool {
        assert(fields.size() == fIdx);
        if (fKind == SparseTensorFieldKind::StorageSpec) {
          fields.push_back(
              SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
        } else {
          // Else simply takes the inputs.
          Value tensor = fKind == SparseTensorFieldKind::ValMemRef
                             ? op.getValues()
                             : op.getLevels()[fIdx];
          TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
          if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
            // Flatten the buffer to batchLvlRank.
            auto reassoc = getReassociationForFlattening(
                mem.getType(), stt.getBatchLvlRank());
            mem = memref::CastOp::create(
                rewriter, loc, fType,
                memref::CollapseShapeOp::create(rewriter, loc, mem, reassoc));
          } else {
            mem = memref::CastOp::create(rewriter, loc, fType, mem);
          }
          fields.push_back(mem);
        }
        return true;
      });
Level trailCOOStart = stt.getAoSCOOStart();
Level trailCOORank = stt.getLvlRank() - trailCOOStart;
// Set up the memory sizes in the storage specifier.
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
  assert(ShapedType::isStatic(stt.getDimShape()[lvl]));
  const auto lt = stt.getLvlType(lvl);
  // Set up the level size.
  auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
  desc.setLvlSize(rewriter, loc, lvl, lvlSize);
  // A single AoS region stores the trailing COO, so no memory sizes need
  // to be set for the levels inside that region.
  if (lvl > trailCOOStart)
    continue;
  // ...
  if (lt.isa<LevelFormat::Dense>()) {
    memSize = arith::MulIOp::create(rewriter, loc, lvlSize, memSize);
    posBack = arith::SubIOp::create(rewriter, loc, memSize, c1);
    continue;
  }
  if (lt.isa<LevelFormat::Batch>()) {
    // Skip batch levels; they are handled by the size of each batch.
    continue;
  }
  // ... loose-compressed positions come in lo/hi pairs:
  memSize = arith::MulIOp::create(rewriter, loc, memSize, c2);
  posBack = arith::SubIOp::create(rewriter, loc, memSize, c1);
  // ... compressed positions carry one extra trailing entry:
  memSize = arith::AddIOp::create(rewriter, loc, memSize, c1);
  // ...
  batched.push_back(posBack);
  // ...
  posBack = arith::SubIOp::create(rewriter, loc, posBack, c1);
  // For the trailing COO region, all coordinates are stored interleaved
  // in one AoS buffer of trailCOORank * memSize elements.
  if (lvl == trailCOOStart) {
    Value cooSz = arith::MulIOp::create(
        rewriter, loc, memSize, constantIndex(rewriter, loc, trailCOORank));
    // ...
  }
}
struct SparseDisassembleOpConverter
    : public OpConversionPattern<DisassembleOp> {
  SparseDisassembleOpConverter(const TypeConverter &typeConverter,
                               MLIRContext *context)
      : OpConversionPattern(typeConverter, context) {}
LogicalResult
matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                           op.getTensor().getType());
  // ...
  desc.getLayout().foreachField([&](FieldIndex fid,
                                    SparseTensorFieldKind fKind, Level lvl,
                                    LevelType lt) -> bool {
    if (fKind == SparseTensorFieldKind::StorageSpec)
      return true;
    // ...
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      sz = desc.getValMemSize(rewriter, loc);
      src = desc.getValMemRef();
      dst = genToMemref(rewriter, loc, op.getOutValues());
      retMem.push_back(dst);
      Type valLenTp = op.getValLen().getType();
      retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
    } else {
      assert(fKind == SparseTensorFieldKind::PosMemRef ||
             fKind == SparseTensorFieldKind::CrdMemRef);
      sz = fKind == SparseTensorFieldKind::PosMemRef
               ? desc.getPosMemSize(rewriter, loc, lvl)
               : desc.getCrdMemSize(rewriter, loc, lvl);
      src = desc.getMemRefField(fid);
      dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
      retMem.push_back(dst);
      // Retrieve the corresponding level-length type.
      Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
      retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
    }
    Value flatOut = dst;
    if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
      auto reassoc =
          getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());
      flatOut = memref::CollapseShapeOp::create(rewriter, loc, dst, reassoc);
    }
    Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
    Value srcMem = genSliceToSize(rewriter, loc, src, sz);
    memref::CopyOp::create(rewriter, loc, srcMem, dstMem);
    return true;
  });
  // Convert the memrefs back to tensors.
  SmallVector<Value> retValues = llvm::to_vector(
      llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
        return bufferization::ToTensorOp::create(
            rewriter, loc, getTensorTypeFromMemRefType(v.getType()), v);
      }));
  // Append the actual memory length used in each field.
  retValues.append(retLen.begin(), retLen.end());
LogicalResult
matchAndRewrite(NewOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  const auto dstTp = getSparseTensorType(op.getResult());
  // Creating a COO tensor with NewOp is handled by direct IR codegen;
  // all other cases are handled by rewriting.
  if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
    return failure();
  // Open the reader and get the dimension sizes.
  SmallVector<Value> dimSizesValues;
  Value dimSizesBuffer;
  Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
                           dimSizesValues, dimSizesBuffer);
  // Get the number of stored entries.
  Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
                             {indexTp}, {reader}, EmitCInterface::Off)
                  .getResult(0);
  // Construct the level sizes and the dim2lvl/lvl2dim buffers.
  SmallVector<Value> lvlSizesValues;
  Value dim2lvlBuffer;
  Value lvl2dimBuffer;
  genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
                lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
  // Construct the allocation for each field.
  Value sizeHint = nse;
  SmallVector<Value> fields;
  createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint,
                    lvlSizesValues, fields);
  // Read the COO tensor data into the memrefs.
  const Type elemTp = dstTp.getElementType();
  const Type crdTp = dstTp.getCrdType();
  SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
                                        overheadTypeFunctionSuffix(crdTp),
                                        primaryTypeFunctionSuffix(elemTp)};
  Value isSorted =
      createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
                     {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
                     EmitCInterface::On)
          .getResult(0);
  // If the destination tensor is a sorted COO, sort the COO tensor data
  // when the input elements are not sorted yet.
  const Level lvlRank = dstTp.getLvlRank();
  if (dstTp.isOrderedLvl(lvlRank - 1)) {
    Value kFalse = constantI1(rewriter, loc, false);
    Value notSorted = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::eq, isSorted, kFalse);
    scf::IfOp ifOp =
        scf::IfOp::create(rewriter, loc, notSorted, /*else*/ false);
    rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
    auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
    SortOp::create(rewriter, loc, nse, xs, ValueRange{ys}, xPerm,
                   constantIndex(rewriter, loc, 0),
                   SparseTensorSortKind::HybridQuickSort);
    rewriter.setInsertionPointAfter(ifOp);
  }
  // Store nse in the first positions buffer.
  const Type posTp = dstTp.getPosType();
  const Value posNse = genCast(rewriter, loc, nse, posTp);
  memref::StoreOp::create(rewriter, loc, posNse, posMemref0, c1);
  // Update the storage specifier.
  Value coordinatesSize = arith::MulIOp::create(
      rewriter, loc, nse, constantIndex(rewriter, loc, lvlRank));
  // ...
  // Release the sparse tensor reader.
  createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                 EmitCInterface::Off);
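// Runtime-reader protocol in the lowering above (summary; the exact call
// sequence follows the createFuncCall() sites in this pattern):
//   1. open the reader for the file and query the dimension sizes,
//   2. getSparseTensorReaderNSE() for the number of stored entries,
//   3. getSparseTensorReaderReadToBuffers<Crd><Elem>() to fill xs/ys,
//   4. sort when the target requires order but the input is not sorted,
//   5. delSparseTensorReader() to release the reader.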
struct SparseHasRuntimeLibraryConverter
    : public OpConversionPattern<HasRuntimeLibraryOp> {
  LogicalResult
  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
void mlir::populateSparseTensorCodegenPatterns(
    const TypeConverter &typeConverter, RewritePatternSet &patterns,
    bool createSparseDeallocs, bool enableBufferInitialization) {
  patterns.add<
      SparseAssembleOpConverter, SparseDisassembleOpConverter,
      SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
      SparseCastConverter, SparseExtractSliceConverter,
      SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter,
      SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter,
      SparseSliceGetterOpConverter<ToSliceOffsetOp,
                                   StorageSpecifierKind::DimOffset>,
      SparseSliceGetterOpConverter<ToSliceStrideOp,
                                   StorageSpecifierKind::DimStride>,
      SparseToPositionsConverter, SparseToCoordinatesConverter,
      SparseToCoordinatesBufferConverter, SparseToValuesConverter,
      SparseConvertConverter, SparseNewConverter,
      SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
      typeConverter, patterns.getContext());
  patterns.add<SparseTensorDeallocConverter>(
      typeConverter, patterns.getContext(), createSparseDeallocs);
  patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
      typeConverter, patterns.getContext(), enableBufferInitialization);
}
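// A minimal driver sketch (assumed pass wiring, not from this file): hook
// the patterns into a dialect conversion with the sparse-tensor-to-buffer
// type converter.
//
//   SparseTensorTypeToBufferConverter converter;
//   RewritePatternSet patterns(ctx);
//   populateSparseTensorCodegenPatterns(converter, patterns,
//                                       /*createSparseDeallocs=*/true,
//                                       /*enableBufferInitialization=*/false);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();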