// Flattens the given value ranges into a single vector of values.
static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
  SmallVector<Value> result;
  for (const auto &vals : values)
    llvm::append_range(result, vals);
  return result;
}
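Since each sparse tensor operand is type-converted into several buffers, 1:N adaptors hand back a range of values per original operand; flattening restores a single operand list. A minimal standalone sketch of the same idea with plain vectors (illustrative only, no MLIR dependency):

#include <vector>

int main() {
  // Three "operands", one of which expanded into several values.
  std::vector<std::vector<int>> ranges = {{1, 2, 3}, {4}, {5, 6}};
  std::vector<int> flat;
  for (const auto &vals : ranges)
    flat.insert(flat.end(), vals.begin(), vals.end()); // append_range
  return flat.size() == 6 ? 0 : 1; // flat = {1, 2, 3, 4, 5, 6}
}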
// Generates a load with proper index typing.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
  idx = genCast(builder, loc, idx, builder.getIndexType());
  return builder.create<memref::LoadOp>(loc, mem, idx);
}
// Generates a store with proper index typing and a properly typed value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
                     Value idx) {
  idx = genCast(builder, loc, idx, builder.getIndexType());
  val = genCast(builder, loc, val,
                cast<ShapedType>(mem.getType()).getElementType());
  builder.create<memref::StoreOp>(loc, val, mem, idx);
}
// Creates a straightforward counting for-loop; each entry of `fields` is
// rebound to the corresponding loop-carried region iter-arg.
static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
                            MutableArrayRef<Value> fields,
                            Value lower = Value()) {
  // ... (`lower` defaults to 0; `one` is the constant index 1)
  scf::ForOp forOp =
      builder.create<scf::ForOp>(loc, lower, upper, one, fields);
  for (unsigned i = 0, e = fields.size(); i < e; i++)
    fields[i] = forOp.getRegionIterArg(i);
  // ... (the insertion point moves into the loop body)
}
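A hedged usage sketch (the surrounding values are hypothetical): because `fields` is rebound to the iter-args, code emitted inside the body transparently threads the storage buffers through the loop.

// Hypothetical values; assumes a builder positioned inside a function body.
SmallVector<Value> fields = {posBuffer, valBuffer}; // loop-carried buffers
scf::ForOp loop = createFor(builder, loc, upperBound, fields);
// ... emit body code that reads/updates fields[0] and fields[1] ...
builder.create<scf::YieldOp>(loc, fields); // yield the carried values
builder.setInsertionPointAfter(loop);
// loop.getResults() now holds the final buffer values.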
// Creates a push-back: appends `value` (cast to the buffer element type) to
// the storage field, then records the grown buffer and its new size.
auto pushBackOp = builder.create<PushBackOp>(
    loc, /* ... current size from the specifier, buffer field ... */
    genCast(builder, loc, value, etp), repeat);
desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
desc.setSpecifierField(builder, loc, specFieldKind, lvl,
                       pushBackOp.getNewSize());
// allocSchemeForRank: allocates a sparse storage scheme for the given rank,
// compounding a running `linear` count over the prefix of dense levels.
for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
  // ... (a loose-compressed level stores lo/hi position pairs, so the
  //      initial position count is doubled before the push-back)
  linear = builder.create<arith::MulIOp>(loc, linear, two);
  // ... (a dense level multiplies the linear size by the level size)
  linear = builder.create<arith::MulIOp>(loc, linear, size);
}
// Reached the values array: allocate and zero-initialize `linear` entries.
createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
               std::nullopt, valZero, linear);
// Creates the allocation for a single storage buffer, optionally
// zero-initialized when buffer initialization is enabled.
static Value createAllocation(OpBuilder &builder, Location loc,
                              MemRefType memRefType, Value sz,
                              bool enableInit) {
  Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
  Type elemType = memRefType.getElementType();
  if (enableInit) {
    Value fillValue = constantZero(builder, loc, elemType);
    builder.create<linalg::FillOp>(loc, fillValue, buffer);
  }
  return buffer;
}
// Creates the dim-sizes array, taking dynamic sizes from `dynSizes` and
// materializing index constants for the static ones.
dimSizesValues.clear();
dimSizesValues.reserve(dimRank);
unsigned i = 0;
for (const Size sz : stt.getDimShape())
  dimSizesValues.push_back(ShapedType::isDynamic(sz)
                               ? dynSizes[i++]
                               : constantIndex(builder, loc, sz));
// createAllocFields: pick capacity heuristics for the positions, coordinates,
// and values buffers, then allocate one buffer per storage field.
Value posHeuristic, crdHeuristic, valHeuristic;
if (stt.isAllDense()) {
  // All-dense: the values buffer needs one entry per point of the level space.
  valHeuristic = lvlSizesValues[0];
  for (Level lvl = 1; lvl < lvlRank; lvl++)
    valHeuristic = builder.create<arith::MulIOp>(loc, valHeuristic,
                                                 lvlSizesValues[lvl]);
} else if (sizeHint) {
  if (stt.getAoSCOOStart() == 0) {
    // AoS COO: one coordinate tuple of length lvlRank per stored entry.
    posHeuristic = constantIndex(builder, loc, 2);
    crdHeuristic = builder.create<arith::MulIOp>(
        loc, constantIndex(builder, loc, lvlRank), sizeHint);
  } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
    // CSR-like: positions get one extra entry beyond the hint.
    posHeuristic = builder.create<arith::AddIOp>(
        loc, sizeHint, constantIndex(builder, loc, 1));
    crdHeuristic = sizeHint;
  } else {
    posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
  }
  valHeuristic = sizeHint;
} else {
  // No hint: start all buffers at a small constant capacity.
  posHeuristic = crdHeuristic = valHeuristic =
      constantIndex(builder, loc, 16);
}
// Allocate each field, dispatching on the field kind.
foreachFieldAndTypeInSparseTensor(
    stt,
    [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
     enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
                 Level /*lvl*/, LevelType /*lt*/) -> bool {
      assert(fields.size() == fIdx);
      Value field;
      // ... (the storage specifier gets its init value; memref fields call
      //      createAllocation with posHeuristic, crdHeuristic, or
      //      valHeuristic, respectively, plus enableInit)
      fields.push_back(field);
      return true;
    });
// Initialize the storage scheme to an empty tensor, recording level sizes.
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
  desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
  // ...
}
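To make the capacity heuristics above concrete, here is a standalone illustration (plain C++, hypothetical sizes) of the guesses computed for a 100x100 matrix with an expected 500 nonzeros:

#include <cstdint>
#include <iostream>

int main() {
  const uint64_t lvlRank = 2, rows = 100, cols = 100, sizeHint = 500;

  // All-dense: the values buffer holds every entry of the level space.
  std::cout << "all-dense val: " << rows * cols << "\n"; // 10000

  // AoS COO at level 0: one coordinate tuple of length lvlRank per nonzero.
  std::cout << "COO pos/crd/val: 2 " << lvlRank * sizeHint << " "
            << sizeHint << "\n"; // 2 1000 500

  // CSR-like (dense, compressed): positions get sizeHint + 1 entries.
  std::cout << "CSR pos/crd/val: " << sizeHint + 1 << " " << sizeHint << " "
            << sizeHint << "\n"; // 501 500 500

  // No hint at all: every buffer starts at a constant capacity of 16.
  std::cout << "default pos/crd/val: 16 16 16\n";
  return 0;
}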
// genCompressed: generates the block specific to the compressed case,
// returning the position at which the insertion at `lvl` continues.
assert(lvl < lvlRank && "Level is out of bounds");
assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
       "Level-rank mismatch");
// ...
// Load the position range [pstart, pstop) under the parent position.
const Value pp1 = builder.create<arith::AddIOp>(loc, parentPos, one);
const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
// A loose-compressed level stores coordinates at stride 2 (lo/hi pairs).
const Value crdStrideC =
    stt.isLooseCompressedLvl(lvl) ? constantIndex(builder, loc, 2) : Value();
const Value msz =
    crdStrideC ? builder.create<arith::DivUIOp>(loc, crdMsz, crdStrideC)
               : crdMsz;
const Value plast = builder.create<arith::SubIOp>(
    loc, genCast(builder, loc, pstop, indexType), one);
// Candidate test: is the position range non-empty?
Value lt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                         pstart, pstop);
types.push_back(boolType);
scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
// Then-branch: check whether the last stored coordinate equals the new one.
Value crd = genLoad(builder, loc, coordinatesAtLvl,
                    crdStrideC
                        ? builder.create<arith::MulIOp>(loc, plast, crdStrideC)
                        : plast);
Value eq = builder.create<arith::CmpIOp>(
    loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
    lvlCoords[lvl]);
builder.create<scf::YieldOp>(loc, eq);
// Else-branch (first insertion): record the start position.
genStore(builder, loc, msz, positionsAtLvl, parentPos);
// ...
// The second if-construct yields all updated fields plus the next position.
for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
  types.push_back(desc.getField(i).getType());
types.push_back(indexType);
scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true);
// New-coordinate branch: append the coordinate and bump the end position.
Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
genStore(builder, loc, mszp1, positionsAtLvl, pp1);
// Prepare the next level, if any.
if ((lvl + 1) < lvlRank)
  allocSchemeForRank(builder, loc, desc, lvl + 1);
// ...
// Propagate the updated fields and return the insertion position.
for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
  desc.setField(i, ifOp2.getResult(o++));
return ifOp2.getResult(o);
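In pseudocode, the block emitted by genCompressed behaves roughly as follows (an illustrative paraphrase of the control flow above, not verbatim output):

// Pseudocode sketch for a compressed level `lvl`:
//
//   pstart = positions[lvl][parentPos]
//   pstop  = positions[lvl][parentPos + 1]
//   plast  = pstop - 1
//   msz    = coordinates[lvl].size()
//   if (pstart < pstop) {
//     isPresent = (coordinates[lvl][plast] == lvlCoords[lvl])
//   } else { // first insertion under this parent
//     isPresent = false
//     positions[lvl][parentPos] = msz
//   }
//   if (isPresent) { // coordinate is already present
//     pnext = plast
//   } else {
//     coordinates[lvl].push_back(lvlCoords[lvl])
//     positions[lvl][parentPos + 1] = msz + 1
//     pnext = msz
//     // prepare level lvl + 1
//   }
//   return pnext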
// genEndInsert: generates insertion finalization code. Compressed levels
// need a cleanup pass that backfills positions not visited during insertion.
for (Level lvl = 0; lvl < lvlRank; lvl++) {
  // ... (only compressed levels past the first need the cleanup loop)
  scf::ForOp loop = createFor(builder, loc, hi, inits, one);
  Value i = loop.getInductionVar();
  Value oldv = loop.getRegionIterArg(0);
  Value newv = genLoad(builder, loc, posMemRef, i);
  // A zero position means the entry was never visited: backfill it with the
  // carried previous value to keep the positions array monotone.
  Value cond = builder.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, newv, posZero);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType), cond,
                                             /*else*/ true);
  // then: store the carried value and keep carrying it
  genStore(builder, loc, oldv, posMemRef, i);
  builder.create<scf::YieldOp>(loc, oldv);
  // else: carry the freshly loaded value
  builder.create<scf::YieldOp>(loc, newv);
  // ...
  builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
}
// genSliceToSize: generates a subview of size `sz` into the given buffer.
auto memTp = llvm::cast<MemRefType>(mem.getType());
if (memTp.getRank() > 1) {
  // ... (higher-rank, batched buffers are handled specially)
}
return builder
    .create<memref::SubViewOp>(
        loc, /* ... rank-1 view of size `sz`, offset 0, stride 1 ... */);
// getReassociationForFlattening: creates the reassociation array that keeps
// each batch dimension and collapses the remaining dimensions into one:
// {0}, {1}, ..., {batchLvls - 1}, {batchLvls, ..., rank - 1}.
SmallVector<ReassociationIndices> ret;
for (unsigned i = 0; i < batchLvls; i++)
  ret.emplace_back(1, i); // one singleton group per batch level
ret.emplace_back();
for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
  ret.back().push_back(i); // collapse the rest into the final group
return ret;
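For intuition, a standalone illustration (plain C++, hypothetical rank and batch count) of the grouping this produces, e.g. {0}, {1, 2, 3} for a rank-4 buffer with one batch level:

#include <cstdio>
#include <vector>

int main() {
  const int rank = 4, batchLvls = 1;
  std::vector<std::vector<int>> groups;
  for (int i = 0; i < batchLvls; i++)
    groups.push_back({i});            // batch dims stay as singleton groups
  groups.push_back({});
  for (int i = batchLvls; i < rank; i++)
    groups.back().push_back(i);       // the remaining dims collapse
  for (const auto &g : groups) {
    std::printf("{");
    for (int v : g)
      std::printf(" %d", v);
    std::printf(" }\n");              // prints { 0 } then { 1 2 3 }
  }
  return 0;
}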
// SparseInsertGenerator: generates (or calls) the level-by-level insertion
// routine for a single sparse tensor element.
class SparseInsertGenerator
    : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
  // ... (inside the generated implementation:)
  const Level lvlRank = stt.getLvlRank();
  // The last lvlRank + 1 args are the level coordinates plus the value.
  SmallVector<Value> coords =
      llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
  Value value = args.back();
  // Generate code for every level, tracking the parent position.
  for (Level lvl = 0; lvl < lvlRank; lvl++) {
    const auto lt = stt.getLvlType(lvl);
    // ... (loose compression stores lo/hi position pairs, so the parent
    //      position is doubled first)
    parentPos = builder.create<arith::MulIOp>(loc, parentPos, two);
    parentPos =
        genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
    // ... (a dense level linearizes: parentPos = parentPos * size + coord)
    Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
    parentPos = builder.create<arith::AddIOp>(loc, mult, coords[lvl]);
  }
  // Reached the actual value append/insert.
  if (!stt.isDenseLvl(lvlRank - 1))
    createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                   std::nullopt, value);
  // ...
  std::string getMangledFuncName() {
    // The mangled name of the function has this format:
    //   <prefix>_<lvlTypes>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
    constexpr const char kInsertFuncNamePrefix[] = "_insert_";
    SmallString<32> nameBuffer;
    llvm::raw_svector_ostream nameOstream(nameBuffer);
    nameOstream << kInsertFuncNamePrefix;
    const Level lvlRank = stt.getLvlRank();
    for (Level l = 0; l < lvlRank; l++) {
      std::string lvlType = toMLIRString(stt.getLvlType(l));
      // Replace or remove punctuation in the level-type spelling.
      std::replace_if(
          lvlType.begin(), lvlType.end(),
          [](char c) { return c == '(' || c == ','; }, '_');
      llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
      nameOstream << lvlType << "_";
    }
    // Static dim sizes are used in the mangled name, since simply using the
    // dynamic shape could lead to name collisions.
    for (const auto sz : stt.getDimShape())
      nameOstream << sz << "_";
    if (!stt.isIdentity())
      nameOstream << stt.getDimToLvl() << "_";
    nameOstream << stt.getElementType() << "_";
    nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
    return nameOstream.str().str();
  }
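A standalone sketch of the sanitization step (the level-type spellings, shape, and widths here are hypothetical, chosen only to show how punctuation is rewritten):

#include <algorithm>
#include <iostream>
#include <string>

int main() {
  std::string name = "_insert_";
  for (std::string lvlType :
       {std::string("dense"), std::string("compressed(nonunique)")}) {
    // Replace '(' and ',' with '_', drop ')' and spaces, as the pass does.
    std::replace_if(
        lvlType.begin(), lvlType.end(),
        [](char c) { return c == '(' || c == ','; }, '_');
    lvlType.erase(std::remove_if(lvlType.begin(), lvlType.end(),
                                 [](char c) { return c == ')' || c == ' '; }),
                  lvlType.end());
    name += lvlType + "_";
  }
  for (int sz : {8, 8})
    name += std::to_string(sz) + "_";
  name += "f64_0_0"; // element type, crd width, pos width
  // prints: _insert_dense_compressed_nonunique_8_8_f64_0_0
  std::cout << name << "\n";
  return 0;
}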
// SparseReturnConverter: rewrites func.return over flattened operands.
LogicalResult
matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ... (replace with func.return over flattenValues(adaptor.getOperands()))
}
// SparseCallConverter: rewrites func.call, expanding each sparse tensor
// result into its group of storage buffers plus specifier.
LogicalResult
matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  Location loc = op.getLoc();
  // (1) Generate the call with flattened return types and operands.
  SmallVector<Type> finalRetTy;
  if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
    return failure();
  auto newCall = rewriter.create<func::CallOp>(
      loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands()));
  // (2) Repack the flat results into one group per original result.
  SmallVector<SmallVector<Value>> packedResultVals;
  unsigned retOffset = 0;
  SmallVector<Type> sparseFlat;
  for (auto ret : op.getResults()) {
    assert(retOffset < newCall.getNumResults());
    auto retType = ret.getType();
    if (failed(typeConverter->convertType(retType, sparseFlat)))
      llvm_unreachable("Failed to convert type in sparse tensor codegen");
    assert(!sparseFlat.empty());
    if (sparseFlat.size() > 1) {
      auto flatSize = sparseFlat.size();
      packedResultVals.emplace_back();
      llvm::append_range(packedResultVals.back(),
                         newCall.getResults().slice(retOffset, flatSize));
      retOffset += flatSize;
    } else {
      // 1:1 conversion: pass the single result through.
      packedResultVals.emplace_back();
      packedResultVals.back().push_back(newCall.getResult(retOffset));
      retOffset++;
    }
    sparseFlat.clear();
  }
  assert(packedResultVals.size() == op.getNumResults());
  rewriter.replaceOpWithMultiple(op, std::move(packedResultVals));
  return success();
}
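A standalone illustration of the repacking (plain C++, hypothetical values): suppose the original call returned (sparse_tensor, index), where the sparse tensor flattens to four values (positions, coordinates, values, specifier) and the index to one.

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> flatResults = {10, 11, 12, 13, 20}; // stand-in SSA values
  std::vector<int> groupSizes = {4, 1};                // from convertType()
  std::vector<std::vector<int>> packed;
  std::size_t retOffset = 0;
  for (int flatSize : groupSizes) {
    // Mirrors newCall.getResults().slice(retOffset, flatSize).
    packed.emplace_back(flatResults.begin() + retOffset,
                        flatResults.begin() + retOffset + flatSize);
    retOffset += flatSize;
  }
  for (const auto &group : packed) {
    std::printf("group:");
    for (int v : group)
      std::printf(" %d", v);
    std::printf("\n"); // group: 10 11 12 13, then group: 20
  }
  return 0;
}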
// SparseLvlOpConverter: lowers sparse_tensor.lvl to a level-size query on
// the storage descriptor.
LogicalResult
matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  std::optional<int64_t> lvl = op.getConstantLvlIndex();
  RankedTensorType srcType = op.getSource().getType();
  // ... (bail out on non-constant levels or unencoded sources)
  auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
  auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
  rewriter.replaceOp(op, sz);
  return success();
}
// SparseReorderCOOConverter: materializes an ordered COO from an unordered
// one by sorting the coordinate buffer together with the values.
LogicalResult
matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  auto srcDesc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
                                              op.getInputCoo().getType());
  // ...
}
// SparseSliceGetterOpConverter: lowers slice metadata getters (offset or
// stride, selected by `kind`) to storage-specifier queries.
template <typename Op, StorageSpecifierKind kind>
class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
  // ...
  LogicalResult
  matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
                                             op.getSlice().getType());
    // ... (query specifier field `kind` at dimension
    //      op.getDim().getZExtValue() and replace the op with it)
  }
};
// SparseCastConverter: only rewrites casts between identically annotated
// types; the converted fields simply pass through.
LogicalResult
matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  if (!encDst || encDst != encSrc)
    return failure();
  rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
  return success();
}
// SparseReMapConverter: after codegen, reinterpret_map simply forwards the
// converted source fields.
LogicalResult
matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
}
// SparseTensorAllocConverter: lowers bufferization.alloc_tensor on a sparse
// tensor to buffer allocations for every storage field.
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  SparseTensorAllocConverter(const TypeConverter &typeConverter,
                             MLIRContext *context, bool enableInit)
      : OpConversionPattern(typeConverter, context),
        enableBufferInitialization(enableInit) {}

  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto resType = getSparseTensorType(op);
    if (!resType.hasEncoding())
      return failure();
    if (op.getCopy()) {
      // Copy-variant: clone every buffer of the source descriptor.
      auto desc = getDescriptorFromTensorTuple(
          adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
      SmallVector<Value> fields;
      for (auto field : desc.getMemRefFields()) {
        auto memrefTp = cast<MemRefType>(field.getType());
        auto size = rewriter.create<memref::DimOp>(loc, field, 0);
        auto copied =
            rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size});
        rewriter.create<memref::CopyOp>(loc, field, copied);
        fields.push_back(copied);
      }
      // ... (forward the specifier and replace the op)
    }
    if (!resType.isIdentity()) {
      return rewriter.notifyMatchFailure(
          op, "try run --sparse-reinterpret-map before codegen");
    }
    // Construct level sizes and allocate all buffers for the storage scheme.
    Value sizeHint = op.getSizeHint();
    // ...
    createAllocFields(rewriter, loc, resType, enableBufferInitialization,
                      sizeHint, lvlSizesValues, fields);
    // ...
  }

private:
  bool enableBufferInitialization;
};
// SparseTensorEmptyConverter: lowers tensor.empty the same way, minus copy.
SparseTensorEmptyConverter(const TypeConverter &typeConverter,
                           MLIRContext *context, bool enableInit)
    : OpConversionPattern(typeConverter, context),
      enableBufferInitialization(enableInit) {}

LogicalResult
matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  const auto resType = getSparseTensorType(op);
  if (!resType.hasEncoding())
    return failure();
  if (!resType.isIdentity()) {
    return rewriter.notifyMatchFailure(
        op, "try run --sparse-reinterpret-map before codegen");
  }
  // ... (construct level sizes; no meaningful size hint is available)
  createAllocFields(rewriter, loc, resType, enableBufferInitialization,
                    sizeHint, lvlSizesValues, fields);
  // ...
}

bool enableBufferInitialization;
// SparseTensorDeallocConverter: replaces a sparse tensor deallocation with
// deallocations of the underlying buffers (when requested).
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  SparseTensorDeallocConverter(const TypeConverter &typeConverter,
                               MLIRContext *context, bool createDeallocs)
      : OpConversionPattern(typeConverter, context),
        createDeallocs(createDeallocs) {}

  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (createDeallocs) {
      auto desc = getDescriptorFromTensorTuple(
          adaptor.getTensor(),
          cast<RankedTensorType>(op.getTensor().getType()));
      for (auto input : desc.getMemRefFields())
        rewriter.create<memref::DeallocOp>(loc, input);
    }
    rewriter.eraseOp(op);
    return success();
  }

private:
  const bool createDeallocs;
};
// SparseTensorLoadConverter: lowers sparse_tensor.load; with pending inserts,
// first generates the insertion finalization code.
LogicalResult
matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                           op.getTensor().getType());
  if (op.getHasInserts())
    genEndInsert(rewriter, op.getLoc(), desc);
  // ... (replace the op with the descriptor fields)
}
// SparseExpandConverter: lowers sparse_tensor.expand by allocating scratch
// buffers for the expanded access pattern.
LogicalResult
matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                           op.getTensor().getType());
  const auto srcType = getSparseTensorType(op.getTensor());
  Type eltType = srcType.getElementType();
  // ...
  // The expansion span equals the innermost level size.
  const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
  const auto genAlloc = [&](Type t) {
    const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
    return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
  };
  // Allocate temporary buffers for values, filled-switch, and added set.
  Value values = genAlloc(eltType);
  Value filled = genAlloc(boolType);
  Value added = genAlloc(idxType);
  // Initialize to all-zero / all-false; `added` needs no initialization.
  rewriter.create<linalg::FillOp>(
      loc, ValueRange{constantZero(rewriter, loc, eltType)},
      ValueRange{values});
  rewriter.create<linalg::FillOp>(
      loc, ValueRange{constantI1(rewriter, loc, false)}, ValueRange{filled});
  // Replace the op with these buffers and an initial coordinate count of 0.
  assert(op.getNumResults() == 4);
  rewriter.replaceOp(op, {values, filled, added, zero});
  return success();
}
// SparseCompressConverter: lowers sparse_tensor.compress, which flushes the
// expanded access pattern (values/filled/added/count) back into the tensor.
LogicalResult
matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  Location loc = op.getLoc();
  SmallVector<Value> fields;
  auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
                                              op.getTensor().getType());
  Value values = llvm::getSingleElement(adaptor.getValues());
  Value filled = llvm::getSingleElement(adaptor.getFilled());
  Value added = llvm::getSingleElement(adaptor.getAdded());
  Value count = llvm::getSingleElement(adaptor.getCount());
  const SparseTensorType dstType(desc.getRankedTensorType());
  Type eltType = dstType.getElementType();
  // If the innermost level is ordered, sort the added coordinates first so
  // that insertions happen in coordinate order.
  if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
    rewriter.create<SortOp>(loc, count, ValueRange{added}, ValueRange{},
                            rewriter.getMultiDimIdentityMap(1),
                            rewriter.getIndexAttr(0),
                            SparseTensorSortKind::HybridQuickSort);
  // Loop over the `count` set entries; each iteration inserts one element
  // and resets the values/filled entries, keeping the runtime proportional
  // to the sparsity of the expanded access pattern.
  scf::ForOp loop = createFor(rewriter, loc, count, fields);
  Value i = loop.getInductionVar();
  // ... (load the coordinate from `added` and the value from `values`)
  SmallVector<Type> flatSpTensorTps = llvm::to_vector(
      llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
  SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
  params.append(flatLvlCoords.begin(), flatLvlCoords.end());
  params.push_back(crd);
  params.push_back(value);
  SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
                                  params, /*genCall=*/true);
  // ... (generate the insertion, reset values[crd] and filled[crd])
  rewriter.create<scf::YieldOp>(loc, insertRet);
  rewriter.setInsertionPointAfter(loop);
  // Deallocate the temporary buffers on exit of the loop.
  rewriter.create<memref::DeallocOp>(loc, values);
  rewriter.create<memref::DeallocOp>(loc, filled);
  rewriter.create<memref::DeallocOp>(loc, added);
  // ... (replace the op with the loop results)
}
// SparseInsertConverter: lowers tensor.insert into a sparse tensor to a call
// to (or an inlined body of) the mangled insertion routine.
LogicalResult
matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  auto stt = getSparseTensorType(op.getDest());
  if (!stt.hasEncoding())
    return failure();
  assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
  // Build the parameter list: fields, coordinates, and the scalar value.
  SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
  params.append(flatIndices.begin(), flatIndices.end());
  params.push_back(llvm::getSingleElement(adaptor.getScalar()));
  SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
                                  params, /*genCall=*/true);
  // ...
}
// SparseToPositionsConverter: replaces the positions access with the
// corresponding buffer of the storage descriptor.
using OpAdaptor = typename ToPositionsOp::Adaptor;
LogicalResult
matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  Level lvl = op.getLevel();
  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                           op.getTensor().getType());
  // ... (replace the op with desc.getPosMemRef(lvl), sized to its mem size)
}
// SparseToCoordinatesConverter: likewise for the coordinates buffer, which
// may be a strided view into a trailing AoS COO region.
class SparseToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
  using OpAdaptor = typename ToCoordinatesOp::Adaptor;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Level lvl = op.getLevel();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
    // ...
  }
};
// SparseToCoordinatesBufferConverter: replaces the access with the raw AoS
// coordinates buffer.
class SparseToCoordinatesBufferConverter
    : public OpConversionPattern<ToCoordinatesBufferOp> {
  using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
  LogicalResult
  matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    // ... (replace the op with desc.getAOSMemRef(), sized appropriately)
  }
};
// SparseToValuesConverter: replaces the values access with the values buffer.
using OpAdaptor = typename ToValuesOp::Adaptor;
LogicalResult
matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                           op.getTensor().getType());
  // ... (replace the op with desc.getValMemRef(), sized to the val mem size)
}
// SparseConvertConverter: handles conversions that only change overhead bit
// widths by copying (and casting) each buffer into a freshly allocated one.
LogicalResult
matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
  SparseTensorEncodingAttr encSrc =
      getSparseTensorEncoding(op.getSource().getType());
  // The output tensor can not be a slice.
  assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slice.");
  if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() /* ... */)
    return failure();
  // Trivial case: identical types fold into a pass-through.
  Type retElemTp = op.getResult().getType().getElementType();
  Type srcElemTp = op.getSource().getType().getElementType();
  if (retElemTp == srcElemTp && encDst == encSrc) {
    rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
    return success();
  }
  auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
                                              op.getSource().getType());
  SmallVector<Value> fields;
  foreachFieldAndTypeInSparseTensor(
      SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
      [&rewriter, &fields, srcDesc,
       loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
            LevelType /*lt*/) -> bool {
        // The specifier is an SSA value and can simply be reused.
        if (fKind == SparseTensorFieldKind::StorageSpec) {
          fields.push_back(srcDesc.getSpecifier());
          return true;
        }
        Value srcMem = srcDesc.getMemRefField(fIdx);
        Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
        auto dstMem = rewriter.create<memref::AllocOp>(
            loc, cast<MemRefType>(fTp), sz);
        if (fTp != srcMem.getType()) {
          // Element types differ: emit an element-wise copy-with-cast loop,
          // since memref.copy requires matching element types.
          scf::buildLoopNest(
              rewriter, loc, constantIndex(rewriter, loc, 0), sz,
              constantIndex(rewriter, loc, 1),
              [srcMem, &dstMem](OpBuilder &builder, Location loc,
                                ValueRange ivs) {
                Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
                Value casted = genCast(builder, loc, v,
                                       dstMem.getType().getElementType());
                builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
              });
        } else {
          rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
        }
        fields.push_back(dstMem);
        return true;
      });
  // ...
}
// SparseExtractSliceConverter: lowers tensor.extract_slice on a sparse
// tensor to an updated storage specifier carrying the slice metadata.
class SparseExtractSliceConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (!srcEnc || !dstEnc || !dstEnc.isSlice())
      return failure();
    assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
    auto desc = getDescriptorFromTensorTuple(adaptor.getSource(),
                                             op.getSource().getType());
    auto newSpec = rewriter.create<StorageSpecifierInitOp>(
        loc, /* ... specifier type for dstEnc ... */, desc.getSpecifier());
    // Record the slice offset, size, and stride for each dimension.
    for (auto [idx, offset, size, stride] : llvm::enumerate(
             op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
      // ...
    }
    // TODO: support dynamic slices.
    assert(srcEnc.isIdentity());
    // ...
  }
};
// SparseNumberOfEntriesConverter: the number of stored entries equals the
// values memory size recorded in the specifier.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    // ... (replace the op with desc.getValMemSize(rewriter, op.getLoc()))
  }
};
// SparseAssembleOpConverter: maps the user-provided level buffers and values
// onto the storage fields, then computes the actual memory sizes.
LogicalResult
matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  foreachFieldAndTypeInSparseTensor(
      stt,
      [&rewriter, &fields, &op, &stt,
       loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
            Level /*lvl*/, LevelType lt) -> bool {
        assert(fields.size() == fIdx);
        if (fKind == SparseTensorFieldKind::StorageSpec) {
          fields.push_back(
              SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
          return true;
        }
        // Every other field simply takes the corresponding input tensor.
        Value tensor = fKind == SparseTensorFieldKind::ValMemRef
                           ? op.getValues()
                           : op.getLevels()[fIdx];
        TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
        if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
          // Flatten the buffer to the batch level rank plus one.
          auto reassoc = getReassociationForFlattening(
              mem.getType(), stt.getBatchLvlRank());
          mem = rewriter.create<memref::CastOp>(
              loc, fType,
              rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
        } else {
          mem = rewriter.create<memref::CastOp>(loc, fType, mem);
        }
        fields.push_back(mem);
        return true;
      });
  // Compute the position/coordinate/value memory sizes level by level,
  // reading the linearized entry count back from the positions buffers.
  Level trailCOOStart = stt.getAoSCOOStart();
  Level trailCOORank = stt.getLvlRank() - trailCOOStart;
  // ...
  for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
    assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
    // Record the level size.
    auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
    desc.setLvlSize(rewriter, loc, lvl, lvlSize);
    // Levels inside the trailing COO region are handled at its start.
    if (lvl > trailCOOStart)
      continue;
    // ...
    if (lt.isa<LevelFormat::Dense>()) {
      // A dense level linearizes the running memory size.
      memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
      posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
      continue;
    }
    if (lt.isa<LevelFormat::Batch>()) {
      // Batch levels are not linearized.
      continue;
    }
    // Loose-compressed positions come in lo/hi pairs; plain compressed
    // positions carry one extra entry.
    if (isLooseCompressedLT(lt)) {
      memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
      posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
    } else {
      posBack = memSize;
      memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
    }
    desc.setPosMemSize(rewriter, loc, lvl, memSize);
    // The last position value is the memory size for the next level.
    batched.push_back(posBack);
    memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
    posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
    // ...
    if (lvl == trailCOOStart) {
      // An AoS COO region shares one coordinate buffer of size
      // memSize * trailCOORank.
      // ...
    }
  }
// SparseDisassembleOpConverter: copies each storage buffer into the matching
// user-provided output tensor and reports the used lengths.
struct SparseDisassembleOpConverter
    : public OpConversionPattern<DisassembleOp> {
  SparseDisassembleOpConverter(const TypeConverter &typeConverter,
                               MLIRContext *context)
      : OpConversionPattern(typeConverter, context) {}

  LogicalResult
  matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    // ...
    desc.getLayout().foreachField([&](FieldIndex fid,
                                      SparseTensorFieldKind fKind, Level lvl,
                                      LevelType lt) -> bool {
      if (fKind == SparseTensorFieldKind::StorageSpec)
        return true;
      Value sz, src;
      TypedValue<BaseMemRefType> dst;
      if (fKind == SparseTensorFieldKind::ValMemRef) {
        sz = desc.getValMemSize(rewriter, loc);
        src = desc.getValMemRef();
        dst = genToMemref(rewriter, loc, op.getOutValues());
        retMem.push_back(dst);
        Type valLenTp = op.getValLen().getType();
        retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
      } else {
        assert(fKind == SparseTensorFieldKind::PosMemRef ||
               fKind == SparseTensorFieldKind::CrdMemRef);
        sz = fKind == SparseTensorFieldKind::PosMemRef
                 ? desc.getPosMemSize(rewriter, loc, lvl)
                 : desc.getCrdMemSize(rewriter, loc, lvl);
        src = desc.getMemRefField(fid);
        dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
        retMem.push_back(dst);
        Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
      }
      Value flatOut = dst;
      if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
        auto reassoc = getReassociationForFlattening(dst.getType(),
                                                     stt.getBatchLvlRank());
        flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
      }
      Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
      Value srcMem = genSliceToSize(rewriter, loc, src, sz);
      rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
      return true;
    });
    // Convert the memrefs back to tensors and append the used lengths.
    SmallVector<Value> retValues = llvm::to_vector(
        llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
          return rewriter.create<bufferization::ToTensorOp>(loc, v);
        }));
    retValues.append(retLen.begin(), retLen.end());
    // ...
  }
};
// SparseNewConverter: lowers sparse_tensor.new (reading from an external
// source) for a destination with a trailing AoS COO region at level 0.
LogicalResult
matchAndRewrite(NewOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
    return failure();
  // Create a sparse tensor reader and obtain the dimension sizes.
  SmallVector<Value> dimSizesValues;
  Value dimSizesBuffer;
  Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
                           dimSizesValues, dimSizesBuffer);
  // Query the number of stored entries.
  Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
                             {indexTp}, {reader}, EmitCInterface::Off)
                  .getResult(0);
  // Set up the level sizes and the dim2lvl/lvl2dim map buffers.
  SmallVector<Value> lvlSizesValues;
  Value dim2lvlBuffer;
  Value lvl2dimBuffer;
  genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
                lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
  // Use the number of entries as the size hint when allocating the fields.
  Value sizeHint = nse;
  SmallVector<Value> fields;
  createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint,
                    lvlSizesValues, fields);
  // Read the COO coordinates (xs) and values (ys) directly into the buffers.
  const Type elemTp = dstTp.getElementType();
  const Type crdTp = dstTp.getCrdType();
  SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
                                        overheadTypeFunctionSuffix(crdTp),
                                        primaryTypeFunctionSuffix(elemTp)};
  Value isSorted =
      createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
                     {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
                     EmitCInterface::On)
          .getResult(0);
  // If the destination's innermost level is ordered, sort when necessary.
  const Level lvlRank = dstTp.getLvlRank();
  if (dstTp.isOrderedLvl(lvlRank - 1)) {
    Value notSorted = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, isSorted, kFalse);
    scf::IfOp ifOp =
        rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
    rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
    rewriter.create<SortOp>(loc, nse, xs, ValueRange{ys}, /* ... */
                            SparseTensorSortKind::HybridQuickSort);
    rewriter.setInsertionPointAfter(ifOp);
  }
  // Finalize the level-0 positions array: store nse at positions[0][1].
  const Type posTp = dstTp.getPosType();
  const Value posNse = genCast(rewriter, loc, nse, posTp);
  rewriter.create<memref::StoreOp>(loc, posNse, posMemref0, c1);
  // Update the coordinates size (nse * lvlRank for the AoS buffer).
  Value coordinatesSize = rewriter.create<arith::MulIOp>(
      loc, nse, constantIndex(rewriter, loc, lvlRank));
  // ...
  // Release the reader and replace the op with the constructed fields.
  createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                 EmitCInterface::Off);
  // ...
}
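The overall flow of this conversion, as a step-list sketch (paraphrasing the code above; names match the fragments, ordering of minor steps is approximate):

// 1. reader   <- genReader(...)            (dimension sizes become known)
// 2. nse      <- getSparseTensorReaderNSE(reader)
// 3. build lvl sizes and the dim2lvl/lvl2dim map buffers
// 4. allocate all storage fields, using nse as the size hint
// 5. isSorted <- getSparseTensorReaderReadToBuffers*(reader, ..., xs, ys)
// 6. if the innermost level is ordered and !isSorted: sort xs/ys
// 7. positions[0][1] <- nse; coordinates size <- nse * lvlRank
// 8. delSparseTensorReader(reader); replace the op with the new fields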
// SparseHasRuntimeLibraryConverter: codegen runs without the runtime
// library, so the query lowers to a constant false (i1 0).
struct SparseHasRuntimeLibraryConverter
    : public OpConversionPattern<HasRuntimeLibraryOp> {
  LogicalResult
  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ... (replace the op with an arith.constant of i1 0)
  }
};
// Public method for populating conversion rules.
void mlir::populateSparseTensorCodegenPatterns(
    const TypeConverter &typeConverter, RewritePatternSet &patterns,
    bool createSparseDeallocs, bool enableBufferInitialization) {
  patterns.add<
      SparseAssembleOpConverter, SparseDisassembleOpConverter,
      SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
      SparseCastConverter, SparseExtractSliceConverter,
      SparseTensorLoadConverter, SparseExpandConverter,
      SparseCompressConverter, SparseInsertConverter,
      SparseReorderCOOConverter, SparseReMapConverter,
      SparseSliceGetterOpConverter<ToSliceOffsetOp,
                                   StorageSpecifierKind::DimOffset>,
      SparseSliceGetterOpConverter<ToSliceStrideOp,
                                   StorageSpecifierKind::DimStride>,
      SparseToPositionsConverter, SparseToCoordinatesConverter,
      SparseToCoordinatesBufferConverter, SparseToValuesConverter,
      SparseConvertConverter, SparseNewConverter,
      SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
      typeConverter, patterns.getContext());
  patterns.add<SparseTensorDeallocConverter>(
      typeConverter, patterns.getContext(), createSparseDeallocs);
  patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
      typeConverter, patterns.getContext(), enableBufferInitialization);
}
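A hedged sketch of how a pass driver would use this entry point (the converter name and the target setup are assumptions, not taken from the listing above):

// Hypothetical driver; assumes a type converter that flattens sparse tensor
// types into their storage buffers, e.g. SparseTensorTypeToBufferConverter.
void runSparseCodegen(ModuleOp module, bool createSparseDeallocs,
                      bool enableBufferInitialization) {
  MLIRContext *ctx = module.getContext();
  SparseTensorTypeToBufferConverter converter; // assumed converter class
  RewritePatternSet patterns(ctx);
  ConversionTarget target(*ctx);
  // ... (mark sparse-tensor ops illegal; memref/scf/arith ops legal)
  populateSparseTensorCodegenPatterns(converter, patterns,
                                      createSparseDeallocs,
                                      enableBufferInitialization);
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    module.emitError("sparse tensor codegen failed");
}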