// Excerpts from the MLIR sparse tensor codegen patterns: sparse tensor types
// are lowered to actual storage buffers, with the static helpers first and
// the conversion patterns afterwards.

// flattenOperands (excerpt): expand every sparse tensor operand (an
// unrealized_conversion_cast "tuple") into its constituent field values;
// all other operands pass through unchanged.
for (auto operand : operands) {
  // ...
    flattened.append(tuple.getOperands().begin(), tuple.getOperands().end());
  // ...
    flattened.push_back(operand);
}
// genLoad: generates a load with proper index typing.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
  idx = genCast(builder, loc, idx, builder.getIndexType());
  return builder.create<memref::LoadOp>(loc, mem, idx);
}
// genStore: generates a store with proper index typing and a properly cast
// value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
                     Value idx) {
  idx = genCast(builder, loc, idx, builder.getIndexType());
  val = genCast(builder, loc, val,
                cast<ShapedType>(mem.getType()).getElementType());
  builder.create<memref::StoreOp>(loc, val, mem, idx);
}
// createFor (excerpt): build a counting loop; `fields` are threaded through
// as iter_args and rebound to the loop's region arguments, so code emitted
// inside the body uses the per-iteration SSA values.
scf::ForOp forOp = builder.create<scf::ForOp>(loc, lower, upper, one, fields);
for (unsigned i = 0, e = fields.size(); i < e; i++)
  fields[i] = forOp.getRegionIterArg(i);
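The iter-arg rebinding above is easy to get wrong, so here is a compile-ready
sketch of the same idiom in isolation. The helper name and the trailing
insertion-point adjustment are illustrative assumptions, not code from this
file:

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Builds for (lower .. upper) step, threading `fields` through as iter_args.
static scf::ForOp createCountingFor(OpBuilder &builder, Location loc,
                                    Value lower, Value upper, Value step,
                                    MutableArrayRef<Value> fields) {
  scf::ForOp forOp =
      builder.create<scf::ForOp>(loc, lower, upper, step, fields);
  for (unsigned i = 0, e = fields.size(); i < e; i++)
    fields[i] = forOp.getRegionIterArg(i); // rebind to block arguments
  builder.setInsertionPointToStart(forOp.getBody());
  return forOp;
}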
// createPushback (excerpt): emit a sparse_tensor.push_back on the proper
// field buffer and record the grown buffer and its new size in the
// descriptor.
auto pushBackOp = builder.create<PushBackOp>(
    // ... current size and buffer taken from the descriptor ...
    genCast(builder, loc, value, etp), repeat);
// ...
desc.setSpecifierField(builder, loc, specFieldKind, lvl,
                       pushBackOp.getNewSize());
// allocSchemeForRank (excerpt): walk the levels from startLvl; a loose
// compressed level doubles the linear count (lo/hi positions), a dense
// level compounds it by the level size, and the trailing values array is
// initialized with `linear` zeros.
for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
  // ...
    linear = builder.create<arith::MulIOp>(loc, linear, two);
  // ...
  linear = builder.create<arith::MulIOp>(loc, linear, size);
}
// ...
createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
               std::nullopt, valZero, linear);
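To isolate the size compounding, the following hedged sketch computes the
capacity contributed by a run of dense levels, i.e. the product of their
sizes. The helper name and signature are assumptions for illustration:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Multiplies out index-typed level sizes; `one` is a caller-created
// index constant serving as the neutral element.
static Value compoundDenseSizes(OpBuilder &builder, Location loc, Value one,
                                ArrayRef<Value> lvlSizes) {
  Value linear = one;
  for (Value size : lvlSizes)
    linear = builder.create<arith::MulIOp>(loc, linear, size);
  return linear;
}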
// createAllocation: creates the allocation operation, optionally
// zero-initializing the buffer.
static Value createAllocation(OpBuilder &builder, Location loc,
                              MemRefType memRefType, Value sz,
                              bool enableInit) {
  Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
  Type elemType = memRefType.getElementType();
  if (enableInit) {
    Value fillValue = constantZero(builder, loc, elemType);
    builder.create<linalg::FillOp>(loc, fillValue, buffer);
  }
  return buffer;
}
// createDimSizes (excerpt): static dimensions become index constants,
// dynamic ones consume the next caller-supplied size.
dimSizesValues.clear();
dimSizesValues.reserve(dimRank);
for (const Size sz : stt.getDimShape())
  dimSizesValues.push_back(ShapedType::isDynamic(sz)
                               ? dynSizes[i++]
                               : constantIndex(builder, loc, sz));
// createAllocFields (excerpt): pick buffer-size heuristics. An all-dense
// tensor pre-sizes the values buffer exactly (the product of the level
// sizes); otherwise a caller-provided size hint seeds the estimates for
// common layouts, and absent any hint a small default (16) is used.
Value posHeuristic, crdHeuristic, valHeuristic;
if (stt.isAllDense()) {
  valHeuristic = lvlSizesValues[0];
  for (Level lvl = 1; lvl < lvlRank; lvl++)
    valHeuristic = builder.create<arith::MulIOp>(loc, valHeuristic,
                                                 lvlSizesValues[lvl]);
} else if (sizeHint) {
  // ... AoS COO starting at level 0:
    crdHeuristic = builder.create<arith::MulIOp>(
        loc, constantIndex(builder, loc, lvlRank), sizeHint);
  // ... CSR-like (dense then compressed):
    posHeuristic = builder.create<arith::AddIOp>(
        loc, sizeHint, constantIndex(builder, loc, 1));
    crdHeuristic = sizeHint;
  // ... otherwise:
    posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
  valHeuristic = sizeHint;
} else {
  posHeuristic = crdHeuristic = valHeuristic =
      constantIndex(builder, loc, 16);
}
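As a concrete illustration of the size-hint branch for a CSR-like layout
(hypothetical numbers, plain C++, not from the source): with an expected
entry count of 42, positions get one extra slot, while coordinates and
values track the hint directly:

#include <cassert>
#include <cstdint>

int main() {
  const uint64_t sizeHint = 42; // expected number of stored entries
  const uint64_t posHeuristic = sizeHint + 1;
  const uint64_t crdHeuristic = sizeHint;
  const uint64_t valHeuristic = sizeHint;
  assert(posHeuristic == 43 && crdHeuristic == 42 && valHeuristic == 42);
  return 0;
}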
// createAllocFields (excerpt): allocate one buffer per storage field, sized
// by the per-kind heuristic (the StorageSpec field gets its init value).
foreachFieldAndTypeInSparseTensor(
    stt,
    [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
     enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
                 Level /*lvl*/, LevelType /*lt*/) -> bool {
      assert(fields.size() == fIdx);
      // ... createAllocation(builder, loc, cast<MemRefType>(fType),
      //     posHeuristic / crdHeuristic / valHeuristic, enableInit) ...
      fields.push_back(field);
      return true; // continue the iteration
    });
// createAllocFields (excerpt): record the level sizes in the specifier.
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
  desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
  // ...
}
// genCompressed (excerpt): for a compressed level, probe the segment
// positions[parentPos .. parentPos + 1) to decide whether the coordinate is
// already present; if not, append it and prepare the next level.
assert(lvl < lvlRank && "Level is out of bounds");
assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
       "Level-rank mismatch");
// ... (declarations of types, indexType, boolType, one, etc. elided)
const Value pp1 = builder.create<arith::AddIOp>(loc, parentPos, one);
// ...
const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
// ...
const Value crdStrideC =
    crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
const Value msz =
    crdStrideC ? builder.create<arith::DivUIOp>(loc, crdMsz, crdStrideC)
               : crdMsz;
const Value plast = builder.create<arith::SubIOp>(
    loc, genCast(builder, loc, pstop, indexType), one);
// Check whether the segment is nonempty.
Value lt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                         pstart, pstop);
types.push_back(boolType);
scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt,
                                            /*else=*/true);
// ... then-region: test whether the last stored coordinate already equals
//     the coordinate being inserted.
Value crd =
    genLoad(builder, loc, desc.getMemRefField(crdFidx),
            crdStrideC ? builder.create<arith::MulIOp>(loc, plast, crdStrideC)
                       : plast);
Value eq = builder.create<arith::CmpIOp>(
    loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
    lvlCoords[lvl]);
builder.create<scf::YieldOp>(loc, eq);
// ... else-region: start a fresh segment in the positions buffer.
genStore(builder, loc, msz, positionsAtLvl, parentPos);
// ...
// The "present" test selects between two continuations; all fields (plus a
// position) are threaded through the scf.if results.
for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
  types.push_back(desc.getField(i).getType());
types.push_back(indexType);
// ...
scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p,
                                            /*else=*/true);
// ... absent-branch: append the coordinate and bump the segment's upper
//     bound.
Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
genStore(builder, loc, mszp1, positionsAtLvl, pp1);
// ...
if ((lvl + 1) < lvlRank)
  allocSchemeForRank(builder, loc, desc, lvl + 1);
// ...
for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
  desc.setField(i, ifOp2.getResult(o++));
return ifOp2.getResult(o);
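The control flow above is easier to follow against a scalar model. The
following standalone C++ sketch (illustrative only, not the generated IR)
mirrors the probe for one compressed level: inspect the segment
positions[parentPos .. parentPos + 1), and append the coordinate only if it
is not already the last entry of that segment:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<uint64_t> positions = {0, 2};   // one segment with 2 entries
  std::vector<uint64_t> coordinates = {3, 7}; // its sorted coordinates
  const uint64_t parentPos = 0, coord = 7;

  uint64_t pstart = positions[parentPos];
  uint64_t pstop = positions[parentPos + 1];
  bool present = pstart < pstop && coordinates[pstop - 1] == coord;
  if (!present) { // absent: append and bump the segment's upper bound
    coordinates.push_back(coord);
    positions[parentPos + 1] = pstop + 1;
  }
  std::cout << "present: " << present << "\n"; // prints "present: 1"
  return 0;
}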
// genEndInsert (excerpt): finalize the insertion pass; position entries
// that were never written (still zero) are patched with the previous entry,
// restoring a monotone positions array.
for (Level lvl = 0; lvl < lvlRank; lvl++) {
  // ... for a compressed level (lvl > 0), clean up the positions buffer:
  scf::ForOp loop = createFor(builder, loc, hi, inits, one);
  Value i = loop.getInductionVar();
  Value oldv = loop.getRegionIterArg(0);
  Value newv = genLoad(builder, loc, posMemRef, i);
  Value cond = builder.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, newv, posZero);
  // ... then-branch: rewrite the untouched (zero) entry.
  genStore(builder, loc, oldv, posMemRef, i);
  builder.create<scf::YieldOp>(loc, oldv);
  // ... else-branch: keep the stored value.
  builder.create<scf::YieldOp>(loc, newv);
  // ... after the if: carry the selected value to the next iteration.
  builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
  // ...
}
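A standalone C++ model (illustrative only) of the cleanup this loop
performs: entries that were never written during insertion are still zero
and are rewritten with the last previously written position:

#include <iostream>
#include <vector>

int main() {
  std::vector<unsigned> pos = {0, 2, 0, 0, 5}; // zeros mark untouched slots
  unsigned oldv = pos[0];
  for (std::size_t i = 1; i < pos.size(); i++) {
    unsigned newv = pos[i];
    if (newv == 0)
      pos[i] = oldv; // rewrite the untouched entry
    else
      oldv = newv;   // carry the last written position forward
  }
  for (unsigned p : pos)
    std::cout << p << " "; // prints "0 2 2 2 5"
  return 0;
}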
// genSliceToSize (excerpt): truncate a linear memref to `sz` entries via a
// subview; higher-ranked (batched) memrefs are returned unchanged, assuming
// the innermost dimension already has the size of the slice.
auto memTp = llvm::cast<MemRefType>(mem.getType());
if (memTp.getRank() > 1)
  return mem;
return builder
    .create<memref::SubViewOp>(
        /* ... offset 0, dynamic size sz, stride 1 ... */)
    .getResult();
// getReassociationForFlattening: keep the leading batch dimensions, then
// collapse all remaining dimensions into one, i.e. the reassociation
// {0}, {1}, ..., {batchLvls - 1}, {batchLvls, ..., rank - 1}.
SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
for (unsigned i = 0; i < batchLvls; i++)
  ret[i].push_back(i);
for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
  ret.back().push_back(i);
return ret;
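A quick standalone check (plain C++) of the reassociation produced for rank
4 with 2 batch levels, which yields the groups {0}, {1}, {2, 3}:

#include <iostream>
#include <vector>

int main() {
  const unsigned batchLvls = 2;
  const int rank = 4;
  std::vector<std::vector<int>> ret(batchLvls + 1);
  for (unsigned i = 0; i < batchLvls; i++)
    ret[i].push_back(i);
  for (int i = batchLvls, e = rank; i < e; i++)
    ret.back().push_back(i);
  for (const auto &group : ret) {
    for (int d : group)
      std::cout << d << " ";
    std::cout << "| ";
  } // prints "0 | 1 | 2 3 |"
  return 0;
}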
// SparseInsertGenerator (excerpt): generates the actual insertion body
// (either inlined or as a reusable function); it walks all levels,
// translating the insertion coordinates into positions.
class SparseInsertGenerator
    : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
  // ...
    const Level lvlRank = stt.getLvlRank();
    // Extract the coordinates and the value from the argument list.
    const SmallVector<Value> coords =
        llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
    Value value = args.back();
    // ...
    for (Level lvl = 0; lvl < lvlRank; lvl++) {
      const auto lt = stt.getLvlType(lvl);
      // Compressed case: loose compression first doubles the parent
      // position (lo/hi pairs), then genCompressed takes over.
      // ...
        parentPos = builder.create<arith::MulIOp>(loc, parentPos, two);
      // ...
      parentPos =
          genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
      // ...
      // Dense case: positions[lvl] = size * positions[lvl-1] + coords[lvl].
      Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
      parentPos = builder.create<arith::AddIOp>(loc, mult, coords[lvl]);
    }
    // Reached the actual value append/insert.
    if (!stt.isDenseLvl(lvlRank - 1))
      createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                     std::nullopt, value);
  // ...
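The dense branch is ordinary row-major linearization. A standalone C++
check (illustrative values) for a 3x4 dense region shows coordinate (2, 1)
landing at position 2 * 4 + 1 = 9:

#include <cassert>
#include <cstdint>

int main() {
  uint64_t parentPos = 0;
  const uint64_t sizes[] = {3, 4};  // level sizes of the dense region
  const uint64_t coords[] = {2, 1}; // insertion coordinates
  for (int lvl = 0; lvl < 2; lvl++)
    parentPos = sizes[lvl] * parentPos + coords[lvl];
  assert(parentPos == 9); // row-major: 2 * 4 + 1
  return 0;
}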
std::string getMangledFuncName() {
  // The mangled name of the function has this format:
  //   <namePrefix>_<LT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
  constexpr const char kInsertFuncNamePrefix[] = "_insert_";
  SmallString<32> nameBuffer;
  llvm::raw_svector_ostream nameOstream(nameBuffer);
  nameOstream << kInsertFuncNamePrefix;
  const Level lvlRank = stt.getLvlRank();
  for (Level l = 0; l < lvlRank; l++) {
    std::string lvlType = toMLIRString(stt.getLvlType(l));
    // Replace/remove punctuation in level properties.
    std::replace_if(
        lvlType.begin(), lvlType.end(),
        [](char c) { return c == '(' || c == ','; }, '_');
    llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
    nameOstream << lvlType << "_";
  }
  // Static dim sizes are used in the generated code while dynamic sizes are
  // loaded from the dimSizes buffer, hence the shape is part of the name.
  for (const auto sz : stt.getDimShape())
    nameOstream << sz << "_";
  // Permutation information is also used in generating insertion.
  if (!stt.isIdentity())
    nameOstream << stt.getDimToLvl() << "_";
  nameOstream << stt.getElementType() << "_";
  nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
  return nameOstream.str().str();
}
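The punctuation scrubbing can be exercised in isolation. This standalone
C++ snippet (illustrative input string, not from the source) shows how a
level-type string is folded into a mangling-safe token:

#include <algorithm>
#include <iostream>
#include <string>

int main() {
  std::string lvlType = "compressed(nonunique, nonordered)";
  std::replace_if(
      lvlType.begin(), lvlType.end(),
      [](char c) { return c == '(' || c == ','; }, '_');
  lvlType.erase(std::remove_if(lvlType.begin(), lvlType.end(),
                               [](char c) { return c == ')' || c == ' '; }),
                lvlType.end());
  std::cout << lvlType << "\n"; // prints "compressed_nonunique_nonordered"
  return 0;
}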
// SparseReturnConverter (excerpt): flatten sparse tensor operands into
// their field values and return those instead.
LogicalResult
matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
}

// SparseCallConverter (excerpt): rewrite calls so that sparse tensor
// results are flattened into per-field results, then repacked for callers.
LogicalResult
matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  SmallVector<Type> finalRetTy;
  if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
    return failure();
  // (1) Generate a new call with flattened return values.
  auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
                                               finalRetTy, flattened);
  // (2) Gather sparse tensor returns. Track the offset of the current
  // return value (of the original call) relative to the new call.
  unsigned retOffset = 0;
  SmallVector<Type> sparseFlat;
  for (auto ret : op.getResults()) {
    assert(retOffset < newCall.getNumResults());
    auto retType = ret.getType();
    if (failed(typeConverter->convertType(retType, sparseFlat)))
      llvm_unreachable("Failed to convert type in sparse tensor codegen");
    // Converted types cannot be empty when the conversion succeeds.
    assert(!sparseFlat.empty());
    if (sparseFlat.size() > 1) {
      auto flatSize = sparseFlat.size();
      packedResultVals.emplace_back();
      llvm::append_range(packedResultVals.back(),
                         newCall.getResults().slice(retOffset, flatSize));
      retOffset += flatSize;
    } else {
      // A 1:1 conversion needs no repacking.
      packedResultVals.emplace_back();
      packedResultVals.back().push_back(newCall.getResult(retOffset));
      retOffset++;
    }
    // ...
  }
  assert(packedResultVals.size() == op.getNumResults());
  rewriter.replaceOpWithMultiple(
      op, llvm::to_vector_of<ValueRange>(packedResultVals));
  return success();
}
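The retOffset bookkeeping generalizes to any mix of flattened and 1:1
results. A standalone C++ model (hypothetical sizes: a sparse tensor
flattening to 4 fields, then a scalar, then another sparse tensor) shows
each original result claiming a contiguous slice of the new call's results:

#include <cassert>
#include <utility>
#include <vector>

int main() {
  // Flattened sizes per original result: sparse (4 fields), scalar, sparse.
  const std::vector<unsigned> flatSizes = {4, 1, 4};
  const unsigned newCallNumResults = 9;
  unsigned retOffset = 0;
  std::vector<std::pair<unsigned, unsigned>> packed; // (offset, length)
  for (unsigned flatSize : flatSizes) {
    assert(retOffset < newCallNumResults);
    packed.emplace_back(retOffset, flatSize); // slice(retOffset, flatSize)
    retOffset += flatSize;
  }
  assert(retOffset == newCallNumResults); // every new result consumed
  return 0;
}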
// SparseLvlOpConverter (excerpt): a constant-indexed lvl query is answered
// directly from the storage specifier.
LogicalResult
matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  std::optional<int64_t> lvl = op.getConstantLvlIndex();
  RankedTensorType srcType = op.getSource().getType();
  // ...
  auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
  rewriter.replaceOp(op, sz);
  return success();
}
// SparseReorderCOOConverter (excerpt): sort the COO buffers in place, then
// forward the input storage as the result.
LogicalResult
matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
                                           op.getInputCoo().getType());
  // ... emit a sort over the AoS coordinates (with values as payload) ...
  rewriter.replaceOp(op, adaptor.getInputCoo());
  return success();
}
// SparseSliceGetterOpConverter: slice metadata queries lower to a storage
// specifier read.
template <typename Op, StorageSpecifierKind kind>
class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
  // ...
  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
                                             op.getSlice().getType());
    auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
                                    op.getDim().getZExtValue());
    rewriter.replaceOp(op, v);
    return success();
  }
};
// SparseCastConverter (excerpt): only rewrite identically annotated
// source/dest types; the cast then folds away.
LogicalResult
matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  auto encDst = getSparseTensorEncoding(op.getType());
  auto encSrc = getSparseTensorEncoding(op.getSource().getType());
  if (!encDst || encDst != encSrc)
    return failure();
  rewriter.replaceOp(op, adaptor.getOperands());
  return success();
}
// SparseReMapConverter (excerpt): reinterpret_map is a no-op under codegen;
// simply forward the source fields.
LogicalResult
matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  rewriter.replaceOp(op, adaptor.getSource());
  return success();
}
// SparseTensorAllocConverter (excerpt): bufferization.alloc_tensor either
// deep-copies an existing sparse tensor (field by field) or allocates
// fresh fields via the helpers above.
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  SparseTensorAllocConverter(const TypeConverter &typeConverter,
                             MLIRContext *context, bool enableInit)
      : OpConversionPattern(typeConverter, context),
        enableBufferInitialization(enableInit) {}

  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto resType = getSparseTensorType(op);
    if (!resType.hasEncoding())
      return failure();
    Location loc = op.getLoc();
    // Deal with copy.
    if (op.getCopy()) {
      auto desc = getDescriptorFromTensorTuple(
          adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
      // Memcpy on the memref fields.
      for (auto field : desc.getMemRefFields()) {
        auto memrefTp = cast<MemRefType>(field.getType());
        auto size = rewriter.create<memref::DimOp>(loc, field, 0);
        // ... allocate `copied` of the same type and size, then:
        rewriter.create<memref::CopyOp>(loc, field, copied);
        fields.push_back(copied);
      }
      // ... reuse the specifier (an SSA value) and replace the op ...
    }
    if (!resType.isIdentity()) {
      return rewriter.notifyMatchFailure(
          op, "try run --sparse-reinterpret-map before codegen");
    }
    // ... level sizes equal dimension sizes under the identity map ...
    Value sizeHint = op.getSizeHint();
    createAllocFields(rewriter, loc, resType, enableBufferInitialization,
                      sizeHint, lvlSizesValues, fields);
    // ...
  }

private:
  bool enableBufferInitialization;
};
// SparseTensorEmptyConverter (excerpt): tensor.empty lowers exactly like an
// allocation without a copy.
SparseTensorEmptyConverter(const TypeConverter &typeConverter,
                           MLIRContext *context, bool enableInit)
    : OpConversionPattern(typeConverter, context),
      enableBufferInitialization(enableInit) {}

LogicalResult
matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  const auto resType = getSparseTensorType(op);
  if (!resType.hasEncoding())
    return failure();
  if (!resType.isIdentity()) {
    return rewriter.notifyMatchFailure(
        op, "try run --sparse-reinterpret-map before codegen");
  }
  // ... create the level sizes, then:
  createAllocFields(rewriter, loc, resType, enableBufferInitialization,
                    sizeHint, lvlSizesValues, fields);
  // ...
}

private:
  bool enableBufferInitialization;
// SparseTensorDeallocConverter (excerpt): when createDeallocs is set,
// replace a sparse tensor deallocation with deallocations of all of its
// field buffers; otherwise just erase the op.
SparseTensorDeallocConverter(const TypeConverter &typeConverter,
                             MLIRContext *context, bool createDeallocs)
    : OpConversionPattern(typeConverter, context),
      createDeallocs(createDeallocs) {}

LogicalResult
matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  if (createDeallocs) {
    auto desc = getDescriptorFromTensorTuple(
        adaptor.getTensor(), cast<RankedTensorType>(op.getTensor().getType()));
    for (auto input : desc.getMemRefFields())
      rewriter.create<memref::DeallocOp>(loc, input);
  }
  // ...
}

private:
  const bool createDeallocs;
// SparseTensorLoadConverter (excerpt): finalize insertions if requested,
// then forward the underlying fields.
LogicalResult
matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                           op.getTensor().getType());
  if (op.getHasInserts())
    genEndInsert(rewriter, op.getLoc(), desc);
  // ...
}
// SparseExpandConverter (excerpt): access-pattern expansion; allocate the
// values/filled/added scratch buffers (sized by the innermost level) and
// reset them on entry of the loop nest.
LogicalResult
matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                           op.getTensor().getType());
  const auto srcType = getSparseTensorType(op.getTensor());
  Type eltType = srcType.getElementType();
  Type boolType = rewriter.getIntegerType(1);
  Type idxType = rewriter.getIndexType();
  // The expansion size is always the innermost stored level size.
  const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
  // Generate a memref for `sz` elements of type `t`.
  const auto genAlloc = [&](Type t) {
    const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
    return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
  };
  // Allocate temporary buffers on the heap (the expanded size may be large,
  // as it envelops a single expanded dense dimension).
  Value values = genAlloc(eltType);
  Value filled = genAlloc(boolType);
  Value added = genAlloc(idxType);
  Value zero = constantZero(rewriter, loc, idxType);
  // Reset the values/filled-switch to all-zero/false. This is O(N), but the
  // cost is amortized over the innermost loops of the expansion.
  rewriter.create<linalg::FillOp>(
      loc, ValueRange{constantZero(rewriter, loc, eltType)},
      ValueRange{values});
  rewriter.create<linalg::FillOp>(
      loc, ValueRange{constantZero(rewriter, loc, boolType)},
      ValueRange{filled});
  // Replace the op with the buffers and the initial coordinate count.
  assert(op.getNumResults() == 4);
  rewriter.replaceOp(op, {values, filled, added, zero});
  return success();
}
// SparseCompressConverter (excerpt): compress the expanded scratch buffers
// back into the sparse tensor.
LogicalResult
matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
                                              op.getTensor().getType());
  Value values = adaptor.getValues();
  Value filled = adaptor.getFilled();
  Value added = adaptor.getAdded();
  Value count = adaptor.getCount();
  const SparseTensorType dstType(desc.getRankedTensorType());
  Type eltType = dstType.getElementType();
  // If the innermost level is ordered, sort the coordinates in the added
  // array prior to applying the compression.
  if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
    rewriter.create<SortOp>(loc, count, ValueRange{added}, ValueRange{},
                            rewriter.getMultiDimIdentityMap(1),
                            rewriter.getIndexAttr(0),
                            SparseTensorSortKind::HybridQuickSort);
  // Iterate only over the set elements, keeping the runtime proportional
  // to the sparsity of the expanded access pattern.
  // ...
  Value i = loop.getInductionVar();
  Value crd = genLoad(rewriter, loc, added, i);
  Value value = genLoad(rewriter, loc, values, crd);
  SmallVector<Type> flatSpTensorTps = llvm::to_vector(
      llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
  // ...
  params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
  params.push_back(crd);
  params.push_back(value);
  SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
                                  params, /*genCall=*/true);
  // ... insert, reset values[crd] and filled[crd], then:
  rewriter.create<scf::YieldOp>(loc, insertRet);
  // Deallocate the scratch buffers on exit of the full loop nest.
  // ...
  rewriter.create<memref::DeallocOp>(loc, values);
  rewriter.create<memref::DeallocOp>(loc, filled);
  rewriter.create<memref::DeallocOp>(loc, added);
  // ...
}
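The semantics of the compression are easiest to see in a scalar model. This
standalone C++ sketch (illustrative data, not the generated IR) sorts the
added coordinates, then moves each value out of the dense scratch buffers
and resets the scratch state:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<double> values = {0, 3.5, 0, 0, 1.25}; // dense scratch values
  std::vector<bool> filled = {false, true, false, false, true};
  std::vector<uint64_t> added = {4, 1}; // coordinates in insertion order
  const uint64_t count = 2;

  std::sort(added.begin(), added.begin() + count); // restore coord order
  for (uint64_t i = 0; i < count; i++) {
    uint64_t crd = added[i];
    std::cout << "insert (" << crd << ", " << values[crd] << ")\n";
    values[crd] = 0; // reset scratch for the next expansion
    filled[crd] = false;
  }
  return 0;
}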
// SparseInsertConverter (excerpt): tensor.insert on a sparse destination
// maps to the shared insertion generator.
LogicalResult
matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  auto stt = getSparseTensorType(op.getDest());
  if (!stt.hasEncoding())
    return failure();
  assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
  // ...
  params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
  params.push_back(adaptor.getScalar());
  SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
                                  params, /*genCall=*/true);
  // ...
}
// The four "query" converters (positions, coordinates, coordinate buffer,
// values) all lower to a field access, restricted via genSliceToSize to the
// actual size so that clients observe size, not capacity.

// SparseToPositionsConverter (excerpt):
using OpAdaptor = typename ToPositionsOp::Adaptor;
LogicalResult
matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  Level lvl = op.getLevel();
  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                           op.getTensor().getType());
  // ... replace op with the sized view of the positions memref at lvl ...
}

// SparseToCoordinatesConverter (excerpt): coordinates may be stored inside
// an AoS COO region, hence the OrView accessor.
using OpAdaptor = typename ToCoordinatesOp::Adaptor;
LogicalResult
matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  Level lvl = op.getLevel();
  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                           op.getTensor().getType());
  auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
  // ...
}

// SparseToCoordinatesBufferConverter (excerpt): returns the linear AoS
// coordinate buffer of the trailing COO region.
using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
LogicalResult
matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                           op.getTensor().getType());
  // ...
}

// SparseToValuesConverter (excerpt): likewise for the values buffer.
using OpAdaptor = typename ToValuesOp::Adaptor;
LogicalResult
matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                           op.getTensor().getType());
  // ...
}
// SparseConvertConverter (excerpt): conversions that only change overhead
// bit-widths (and/or the element type) are lowered to a field-by-field copy.
LogicalResult
matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
  SparseTensorEncodingAttr encSrc =
      getSparseTensorEncoding(op.getSource().getType());
  assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slice.");
  // Truly different encodings (and slice inputs) are handled by rewriting.
  if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
      encSrc.isSlice())
    return failure();
  Type retElemTp = op.getResult().getType().getElementType();
  Type srcElemTp = op.getSource().getType().getElementType();
  // Fold the trivial cases.
  if (retElemTp == srcElemTp && encDst == encSrc) {
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
  // Element-wise type conversion without using InsertOp.
  auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
                                              op.getSource().getType());
  foreachFieldAndTypeInSparseTensor(
      // ...
      [&rewriter, &fields, srcDesc,
       loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
            LevelType /*lt*/) -> bool {
        // The storage specifier is an SSA value and is simply reused.
        if (fKind == SparseTensorFieldKind::StorageSpec) {
          fields.push_back(srcDesc.getSpecifier());
        } else {
          Value srcMem = srcDesc.getMemRefField(fIdx);
          Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
          auto dstMem = rewriter.create<memref::AllocOp>(
              loc, cast<MemRefType>(fTp), sz);
          if (fTp != srcMem.getType()) {
            // Element-wise copy with a cast between overhead types.
            scf::buildLoopNest(
                rewriter, loc, constantIndex(rewriter, loc, 0), sz,
                constantIndex(rewriter, loc, 1),
                [srcMem, &dstMem](OpBuilder &builder, Location loc,
                                  ValueRange ivs) {
                  Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
                  Value casted = genCast(builder, loc, v,
                                         dstMem.getType().getElementType());
                  builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
                });
          } else {
            rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
          }
          fields.push_back(dstMem);
        }
        return true;
      });
  // ...
}
// SparseExtractSliceConverter (excerpt): a slice shares all memref fields
// with its source; only the storage specifier changes, picking up the
// offset/size/stride metadata.
class SparseExtractSliceConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (!srcEnc || !dstEnc || !dstEnc.isSlice())
      return failure();
    assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
    // ...
    auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields,
                                                op.getSource().getType());
    auto newSpec = rewriter.create<StorageSpecifierInitOp>(
        loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
    desc.setSpecifier(newSpec);
    // Fill in the slice information.
    for (auto [idx, offset, size, stride] : llvm::enumerate(
             op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
      // ...
      assert(srcEnc.isIdentity());
      // ... record offset, size, and stride in the specifier ...
    }
    // ...
  }
};
// SparseNumberOfEntriesConverter (excerpt): query the memSize of the
// actually stored values.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
    return success();
  }
};
// SparseAssembleOpConverter (excerpt): first adopt the user-provided
// buffers as the storage fields, then compute the memory sizes for the
// specifier by walking the levels.
LogicalResult
matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  foreachFieldAndTypeInSparseTensor(
      stt,
      [&rewriter, &fields, &op, &stt,
       loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
            Level /*lvl*/, LevelType lt) -> bool {
        assert(fields.size() == fIdx);
        if (fKind == SparseTensorFieldKind::StorageSpec) {
          fields.push_back(
              SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
        } else {
          // Else simply take the inputs.
          Value tensor = fKind == SparseTensorFieldKind::ValMemRef
                             ? op.getValues()
                             : op.getLevels()[fIdx];
          TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
          if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
            // Flatten the buffer to batchLvlRank + 1 dimensions.
            auto reassoc = getReassociationForFlattening(
                mem.getType(), stt.getBatchLvlRank());
            mem = rewriter.create<memref::CastOp>(
                loc, fType,
                rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
          } else {
            mem = rewriter.create<memref::CastOp>(loc, fType, mem);
          }
          fields.push_back(mem);
        }
        return true;
      });
  // ...
  Level trailCOOStart = stt.getAoSCOOStart();
  Level trailCOORank = stt.getLvlRank() - trailCOOStart;
  // Set up the specifier: level sizes and memory sizes.
  for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
    assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
    auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
    desc.setLvlSize(rewriter, loc, lvl, lvlSize);
    // A single AoS array stores the trailing COO region, so there is only
    // one memory size to set for the entire COO section.
    if (lvl > trailCOOStart)
      continue;
    LevelType lt = stt.getLvlType(lvl);
    // A dense level simply forwards the compounded position index.
    if (lt.isa<LevelFormat::Dense>()) {
      memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
      posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
      continue;
    }
    // Batch levels are skipped, as they are not linearized.
    if (lt.isa<LevelFormat::Batch>()) {
      continue;
    }
    // Position-carrying levels: loose compression doubles the position
    // entries; plain compression uses "size + 1" positions.
    // ...
      memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
      posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
    // ...
      memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
    // ...
    // The last value in the position array is the memory size for the
    // next level.
    batched.push_back(posBack);
    // ...
    posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
    // The COO section sets one combined coordinate memory size:
    if (lvl == trailCOOStart) {
      // ...
    }
  }
  // ...
}
// SparseDisassembleOpConverter (excerpt): the inverse of assemble; each
// stored field is copied out into a user-provided buffer, together with
// its actually-used length.
struct SparseDisassembleOpConverter
    : public OpConversionPattern<DisassembleOp> {
  SparseDisassembleOpConverter(const TypeConverter &typeConverter,
                               MLIRContext *context)
      : OpConversionPattern(typeConverter, context) {}

  LogicalResult
  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    // ...
    desc.getLayout().foreachField([&](FieldIndex fid,
                                      SparseTensorFieldKind fKind, Level lvl,
                                      LevelType lt) -> bool {
      if (fKind == SparseTensorFieldKind::StorageSpec)
        return true;
      // ...
      if (fKind == SparseTensorFieldKind::ValMemRef) {
        // ...
        dst = genToMemref(rewriter, loc, op.getOutValues());
        retMem.push_back(dst);
        Type valLenTp = op.getValLen().getType();
        // ...
      } else {
        assert(fKind == SparseTensorFieldKind::PosMemRef ||
               fKind == SparseTensorFieldKind::CrdMemRef);
        sz = fKind == SparseTensorFieldKind::PosMemRef
                 ? desc.getPosMemSize(rewriter, loc, lvl)
                 : desc.getCrdMemSize(rewriter, loc, lvl);
        dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
        retMem.push_back(dst);
        // Retrieve the corresponding level-length type.
        Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
        // ...
      }
      Value flatOut = dst;
      if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
        auto reassoc = getReassociationForFlattening(dst.getType(),
                                                     stt.getBatchLvlRank());
        flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
      }
      // ... size both sides with genSliceToSize, then:
      rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
      return true;
    });
    // Convert the memrefs back to tensors.
    SmallVector<Value> retValues = llvm::to_vector(
        llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
          return rewriter.create<bufferization::ToTensorOp>(loc, v);
        }));
    // Append the actual memory length used in each field.
    retValues.append(retLen.begin(), retLen.end());
    // ...
  }
};
// SparseNewConverter (excerpt): direct codegen of sparse_tensor.new for
// tensors with a leading AoS COO region; all other layouts are handled by
// rewriting.
LogicalResult
matchAndRewrite(NewOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
    return failure();
  // Open the reader and obtain the dimension sizes.
  SmallVector<Value> dimSizesValues;
  Value dimSizesBuffer;
  Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
                           dimSizesValues, dimSizesBuffer);
  // Get the number of stored entries.
  Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
                             {indexTp}, {reader}, EmitCInterface::Off)
                  .getResult(0);
  // Construct the lvl sizes and the dim2lvl/lvl2dim buffers.
  SmallVector<Value> lvlSizesValues;
  Value dim2lvlBuffer;
  Value lvl2dimBuffer;
  genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
                lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
  // Construct the allocation for each field, using nse as the size hint.
  Value sizeHint = nse;
  SmallVector<Value> fields;
  createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint,
                    lvlSizesValues, fields);
  // Read the COO tensor data into the coordinate/value buffers.
  const Type elemTp = dstTp.getElementType();
  const Type crdTp = dstTp.getCrdType();
  SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
                                        overheadTypeFunctionSuffix(crdTp),
                                        primaryTypeFunctionSuffix(elemTp)};
  Value isSorted =
      createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
                     {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
                     EmitCInterface::On)
          .getResult(0);
  // If the destination is an ordered COO, sort the data when the input
  // elements were not already sorted.
  const Level lvlRank = dstTp.getLvlRank();
  if (dstTp.isOrderedLvl(lvlRank - 1)) {
    Value notSorted = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, isSorted, kFalse);
    scf::IfOp ifOp =
        rewriter.create<scf::IfOp>(loc, notSorted, /*else=*/false);
    // ... emit a SortOp with SparseTensorSortKind::HybridQuickSort ...
  }
  // Update positions and the storage specifier: positions[0][1] is nse, the
  // coordinate memory size is nse * lvlRank, the value memory size is nse.
  const Type posTp = dstTp.getPosType();
  const Value posNse = genCast(rewriter, loc, nse, posTp);
  rewriter.create<memref::StoreOp>(loc, posNse, posMemref0, c1);
  Value coordinatesSize = rewriter.create<arith::MulIOp>(
      loc, nse, constantIndex(rewriter, loc, lvlRank));
  // ...
  // Release the sparse tensor reader.
  createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                 EmitCInterface::Off);
  // ...
}
// SparseHasRuntimeLibraryConverter (excerpt): under direct codegen there is
// no runtime library, so the op lowers to the i1 constant false.
struct SparseHasRuntimeLibraryConverter
    : public OpConversionPattern<HasRuntimeLibraryOp> {
  LogicalResult
  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
  }
};
// populateSparseTensorCodegenPatterns: sets up the sparse tensor codegen
// rules by registering all the conversion patterns defined above.
void mlir::populateSparseTensorCodegenPatterns(
    const TypeConverter &typeConverter, RewritePatternSet &patterns,
    bool createSparseDeallocs, bool enableBufferInitialization) {
  patterns.add<
      SparseAssembleOpConverter, SparseDisassembleOpConverter,
      SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
      SparseCastConverter, SparseExtractSliceConverter,
      SparseTensorLoadConverter, SparseExpandConverter,
      SparseCompressConverter, SparseInsertConverter,
      SparseReorderCOOConverter, SparseReMapConverter,
      SparseSliceGetterOpConverter<ToSliceOffsetOp,
                                   StorageSpecifierKind::DimOffset>,
      SparseSliceGetterOpConverter<ToSliceStrideOp,
                                   StorageSpecifierKind::DimStride>,
      SparseToPositionsConverter, SparseToCoordinatesConverter,
      SparseToCoordinatesBufferConverter, SparseToValuesConverter,
      SparseConvertConverter, SparseNewConverter,
      SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
      typeConverter, patterns.getContext());
  patterns.add<SparseTensorDeallocConverter>(
      typeConverter, patterns.getContext(), createSparseDeallocs);
  patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
      typeConverter, patterns.getContext(), enableBufferInitialization);
}
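A hedged usage sketch: a conversion pass would call this entry point while
setting up its rewrite patterns. The wrapper below is illustrative (the
wrapper name and flag values are assumptions); pass and type-converter
boilerplate are elided:

#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

void addSparseCodegenPatterns(const TypeConverter &typeConverter,
                              RewritePatternSet &patterns) {
  // Request explicit buffer deallocation, no eager zero-initialization.
  populateSparseTensorCodegenPatterns(typeConverter, patterns,
                                      /*createSparseDeallocs=*/true,
                                      /*enableBufferInitialization=*/false);
}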
Static helpers defined in this file:

  static void flattenOperands(ValueRange operands,
                              SmallVectorImpl<Value> &flattened)
      Flattens a list of operands that may contain sparse tensors.

  static void createAllocFields(OpBuilder &builder, Location loc,
                                SparseTensorType stt, bool enableInit,
                                Value sizeHint,
                                SmallVectorImpl<Value> &lvlSizesValues,
                                SmallVectorImpl<Value> &fields)
      Creates the allocation for each field in the sparse tensor type.

  static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
                              MutableArrayRef<Value> fields,
                              Value lower = Value())
      Creates a straightforward counting for-loop.

  static SmallVector<ReassociationIndices>
  getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls)
      Creates the reassociation array.

  static void genEndInsert(OpBuilder &builder, Location loc,
                           SparseTensorDescriptor desc)
      Generates insertion finalization code.

  static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
                       Value idx)
      Generates a store with proper index typing and a proper value.

  static void allocSchemeForRank(OpBuilder &builder, Location loc,
                                 MutSparseTensorDescriptor desc,
                                 Level startLvl)
      Generates code that allocates a sparse storage scheme for a given rank.

  static Value genCompressed(OpBuilder &builder, Location loc,
                             MutSparseTensorDescriptor desc,
                             ValueRange lvlCoords, Value, Value parentPos,
                             Level lvl)
      Helper method that generates the block specific to the compressed case.

  static Value createAllocation(OpBuilder &builder, Location loc,
                                MemRefType memRefType, Value sz,
                                bool enableInit)
      Creates the allocation operation.

  static void createPushback(OpBuilder &builder, Location loc,
                             MutSparseTensorDescriptor desc,
                             SparseTensorFieldKind kind,
                             std::optional<Level> lvl, Value value,
                             Value repeat = Value())
      Creates a push back operation.

  static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
                              Value sz)
      Generates a subview into the sizes.

  static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
      Generates a load with proper index typing.

  static void createDimSizes(OpBuilder &builder, Location loc,
                             SparseTensorType stt, ValueRange dynSizes,
                             SmallVectorImpl<Value> &dimSizesValues)
      Creates the dim sizes array, filling in from the dynamic sizes.