// Fragmentary excerpt of the sparse tensor codegen rewrite rules
// (populateSparseTensorCodegenPatterns and its helpers). "// ..." and
// "/*...*/" mark source elided from this excerpt; call heads adjoining an
// elision are reconstructed from the declarations documented in this file
// and may not match the upstream source verbatim.

// flattenValues: flattens the given value ranges into a single vector.
static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
  SmallVector<Value> result;
  for (const auto &vals : values)
    llvm::append_range(result, vals);
  return result;
}
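// getSingleValue: asserts that the given value range contains a single
// value and returns it.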
static Value getSingleValue(ValueRange values) {
  assert(values.size() == 1 && "expected single value");
  return values.front();
}
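// genLoad: generates a load with proper index typing.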
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
  // ... (casts idx to index type)
  return builder.create<memref::LoadOp>(loc, mem, idx);
}
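// genStore: generates a store with proper index typing and a properly
// typed value.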
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
                     Value idx) {
  // ...
  val = genCast(builder, loc, val,
                cast<ShapedType>(mem.getType()).getElementType());
  builder.create<memref::StoreOp>(loc, val, mem, idx);
}
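// createFor: creates a straightforward counting for-loop, rebinding the
// given fields to the loop's region iteration arguments.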
static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
                            MutableArrayRef<Value> fields,
                            Value lower = Value()) {
  // ...
  scf::ForOp forOp =
      builder.create<scf::ForOp>(loc, lower, upper, one, fields);
  for (unsigned i = 0, e = fields.size(); i < e; i++)
    fields[i] = forOp.getRegionIterArg(i);
  // ...
  return forOp;
}
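// createPushback: creates a push_back operation and records the grown
// buffer and its new size in the descriptor.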
auto pushBackOp = builder.create<PushBackOp>(
    loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field,
    genCast(builder, loc, value, etp), repeat);
// ...
desc.setSpecifierField(builder, loc, specFieldKind, lvl,
                       pushBackOp.getNewSize());
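// allocSchemeForRank: generates code that allocates a sparse storage
// scheme for the given rank.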
for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
  // ...
  linear = builder.create<arith::MulIOp>(loc, linear, two);
  // ...
  linear = builder.create<arith::MulIOp>(loc, linear, size);
  // ...
}
// Reached the values array; prepare for an insertion.
createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
               std::nullopt, valZero, linear);
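// createAllocation: creates a buffer allocation, zero-initialized when
// enableInit is set.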
static Value createAllocation(OpBuilder &builder, Location loc,
                              MemRefType memRefType, Value sz,
                              bool enableInit) {
  Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
  Type elemType = memRefType.getElementType();
  if (enableInit) {
    Value fillValue = constantZero(builder, loc, elemType);
    builder.create<linalg::FillOp>(loc, fillValue, buffer);
  }
  return buffer;
}
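// createDimSizes: creates the dim-sizes array, filling in from dynamic
// sizes.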
dimSizesValues.clear();
dimSizesValues.reserve(dimRank);
// ...
dimSizesValues.push_back(ShapedType::isDynamic(sz)
                             ? dynSizes[i++]
                             : constantIndex(builder, loc, sz));
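// createAllocFields: creates an allocation for each field of the sparse
// tensor, with simple growth heuristics for the positions, coordinates,
// and values buffers.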
Value posHeuristic, crdHeuristic, valHeuristic;
if (stt.isAllDense()) {
  valHeuristic = lvlSizesValues[0];
  for (Level lvl = 1; lvl < lvlRank; lvl++)
    valHeuristic =
        builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
} else if (sizeHint) {
  // ...
  crdHeuristic = builder.create<arith::MulIOp>(
      /*...*/);
  // ...
  posHeuristic = builder.create<arith::AddIOp>(
      /*...*/);
  crdHeuristic = sizeHint;
  // ...
  posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
  // ...
  valHeuristic = sizeHint;
} else {
  posHeuristic = crdHeuristic = valHeuristic =
      constantIndex(builder, loc, 16);
}
foreachFieldAndTypeInSparseTensor(
    stt,
    [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
     enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
                 Level /*lvl*/, LevelType /*lt*/) -> bool {
      assert(fields.size() == fIdx);
      // ... allocate the buffer for each field kind:
      field = createAllocation(builder, loc, cast<MemRefType>(fType),
                               posHeuristic, enableInit);
      // ...
      field = createAllocation(builder, loc, cast<MemRefType>(fType),
                               crdHeuristic, enableInit);
      // ...
      field = createAllocation(builder, loc, cast<MemRefType>(fType),
                               valHeuristic, enableInit);
      // ...
      fields.push_back(field);
      // ...
    });
// Initialize the storage scheme to an empty tensor.
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
  desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
  // ...
}
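// genCompressed: generates the block specific to the compressed case:
// probes the position range [pstart, pstop), tests whether the coordinate
// is already present, and expands the storage otherwise.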
static Value genCompressed(OpBuilder &builder, Location loc,
                           MutSparseTensorDescriptor desc, ValueRange lvlCoords,
                           Value /*unused*/, Value parentPos, Level lvl) {
  // ...
  assert(lvl < lvlRank && "Level is out of bounds");
  assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
         "Level-rank mismatch");
  // ...
  const Value positionsAtLvl = desc.getPosMemRef(lvl);
  const Value pp1 = builder.create<arith::AddIOp>(loc, parentPos, one);
  // ...
  const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
  const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
  // ...
  const Value crdStrideC = /*...*/;
  const Value msz =
      crdStrideC ? builder.create<arith::DivUIOp>(loc, crdMsz, crdStrideC)
                 : crdMsz;
  const Value plast = builder.create<arith::SubIOp>(
      loc, genCast(builder, loc, pstop, indexType), one);
  // Candidate coordinate within bounds?
  Value lt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                           pstart, pstop);
  types.push_back(boolType);
  scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else=*/true);
  // ...
  Value crd = genLoad(
      builder, loc, /*...*/,
      crdStrideC ? builder.create<arith::MulIOp>(loc, plast, crdStrideC)
                 : plast);
  Value eq = builder.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
      lvlCoords[lvl]);
  builder.create<scf::YieldOp>(loc, eq);
  // ...
  genStore(builder, loc, msz, positionsAtLvl, parentPos);
  // ...
  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
    types.push_back(desc.getField(i).getType());
  types.push_back(indexType);
  // ...
  scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else=*/true);
  // ...
  Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
  genStore(builder, loc, mszp1, positionsAtLvl, pp1);
  // ...
  if ((lvl + 1) < lvlRank)
    allocSchemeForRank(builder, loc, desc, lvl + 1);
  // ...
  unsigned o = 0;
  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
    desc.setField(i, ifOp2.getResult(o++));
  return ifOp2.getResult(o);
}
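// genEndInsert: generates insertion finalization code.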
static void genEndInsert(OpBuilder &builder, Location loc,
                         SparseTensorDescriptor desc) {
  // ...
  for (Level lvl = 0; lvl < lvlRank; lvl++) {
    // ...
    scf::ForOp loop = createFor(builder, loc, hi, inits, one);
    Value i = loop.getInductionVar();
    Value oldv = loop.getRegionIterArg(0);
    // ...
    Value cond = builder.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, newv, posZero);
    // ...
    genStore(builder, loc, oldv, posMemRef, i);
    builder.create<scf::YieldOp>(loc, oldv);
    // ...
    builder.create<scf::YieldOp>(loc, newv);
    // ...
    builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
    // ...
  }
}
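// genSliceToSize: generates a subview into the sizes.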
static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
                            Value sz) {
  auto memTp = llvm::cast<MemRefType>(mem.getType());
  // ...
  if (memTp.getRank() > 1)
    // ...
  return builder
      .create<memref::SubViewOp>(
          /*...*/);
}
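// getReassociationForFlattening: creates the reassociation array for
// collapsing the non-batch dimensions into a single group.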
static SmallVector<ReassociationIndices>
getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
  SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
  // Leading batch levels remain unit groups; trailing levels collapse.
  for (unsigned i = 0; i < batchLvls; i++)
    ret[i].push_back(i);
  for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
    ret.back().push_back(i);
  return ret;
}
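// SparseInsertGenerator: generates the insertion code along a level path,
// either inlined or as a call to a generated helper function.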
class SparseInsertGenerator
    : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
  // ...
  const Level lvlRank = stt.getLvlRank();
  // ...
  SmallVector<Value> coords =
      llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
  Value value = args.back();
  // ...
  for (Level lvl = 0; lvl < lvlRank; lvl++) {
    const auto lt = stt.getLvlType(lvl);
    // ...
    parentPos = builder.create<arith::MulIOp>(loc, parentPos, two);
    // ...
    parentPos =
        genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
    // ...
    Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
    parentPos = builder.create<arith::AddIOp>(loc, mult, coords[lvl]);
    // ...
  }
  // ...
  if (!stt.isDenseLvl(lvlRank - 1))
    createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                   std::nullopt, value);
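// getMangledFuncName: builds a unique, mangled name for the insertion
// helper from the level types, dim shape, dim-to-lvl map, element type,
// and coordinate/position bit widths.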
std::string getMangledFuncName() {
  // ...
  constexpr const char kInsertFuncNamePrefix[] = "_insert_";
  SmallString<32> nameBuffer;
  llvm::raw_svector_ostream nameOstream(nameBuffer);
  nameOstream << kInsertFuncNamePrefix;
  const Level lvlRank = stt.getLvlRank();
  for (Level l = 0; l < lvlRank; l++) {
    std::string lvlType = toMLIRString(stt.getLvlType(l));
    // Replace/remove punctuation in level properties.
    std::replace_if(
        lvlType.begin(), lvlType.end(),
        [](char c) { return c == '(' || c == ','; }, '_');
    llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
    nameOstream << lvlType << "_";
  }
  // ...
  for (const auto sz : stt.getDimShape())
    nameOstream << sz << "_";
  // ...
  if (!stt.isIdentity())
    nameOstream << stt.getDimToLvl() << "_";
  nameOstream << stt.getElementType() << "_";
  nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
  return nameOstream.str().str();
}
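// SparseReturnConverter: rewrites func.return to operate on the flattened
// sparse tensor fields.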
matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
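// SparseCallConverter: rewrites func.call, flattening the operands and
// re-packing the flattened call results per original result.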
matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
    return failure();
  // ...
  auto newCall = rewriter.create<func::CallOp>(
      loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands()));
  // ...
  unsigned retOffset = 0;
  // ...
  for (auto ret : op.getResults()) {
    assert(retOffset < newCall.getNumResults());
    auto retType = ret.getType();
    if (failed(typeConverter->convertType(retType, sparseFlat)))
      llvm_unreachable("Failed to convert type in sparse tensor codegen");
    // ...
    assert(!sparseFlat.empty());
    if (sparseFlat.size() > 1) {
      auto flatSize = sparseFlat.size();
      packedResultVals.emplace_back();
      llvm::append_range(packedResultVals.back(),
                         newCall.getResults().slice(retOffset, flatSize));
      retOffset += flatSize;
    } else {
      // ...
      packedResultVals.emplace_back();
      packedResultVals.back().push_back(newCall.getResult(retOffset));
      // ...
    }
  }
  // ...
  assert(packedResultVals.size() == op.getNumResults());
  rewriter.replaceOpWithMultiple(
      op, llvm::to_vector_of<ValueRange>(packedResultVals));
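// SparseLvlOpConverter: folds LvlOp into a level-size query on the
// storage descriptor.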
matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  std::optional<int64_t> lvl = op.getConstantLvlIndex();
  RankedTensorType srcType = op.getSource().getType();
  // ...
  auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
  // ...
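// SparseReorderCOOConverter: lowers ReorderCOOOp on the COO storage
// descriptor.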
matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  auto srcDesc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
                                              op.getInputCoo().getType());
  // ...
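// SparseSliceGetterOpConverter: reads a slice offset/stride from the
// storage specifier.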
template <typename Op, StorageSpecifierKind kind>
class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
  // ...
  matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
                                             op.getSlice().getType());
    // ...
        op.getDim().getZExtValue());
    // ...
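// SparseCastConverter: handles tensor.cast between identical sparse
// encodings by forwarding the flattened operands.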
matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  if (!encDst || encDst != encSrc)
    return failure();
  // ...
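// SparseReMapConverter: conversion rule for ReinterpretMapOp.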
matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
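// SparseTensorAllocConverter: lowers bufferization.alloc_tensor with a
// sparse encoding into its constituent field buffers (copying them when a
// copy source is given).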
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  // ...
  SparseTensorAllocConverter(const TypeConverter &typeConverter,
                             MLIRContext *context, bool enableInit)
      : OpConversionPattern(typeConverter, context),
        enableBufferInitialization(enableInit) {}

  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (!resType.hasEncoding())
      return failure();
    // ...
    auto srcDesc = getDescriptorFromTensorTuple(
        adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
    // Create a memref for each field and copy the underlying data.
    for (auto field : srcDesc.getMemRefFields()) {
      auto memrefTp = cast<MemRefType>(field.getType());
      auto size = rewriter.create<memref::DimOp>(loc, field, 0);
      auto copied =
          rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size});
      rewriter.create<memref::CopyOp>(loc, field, copied);
      fields.push_back(copied);
    }
    // ...
    if (!resType.isIdentity()) {
      return rewriter.notifyMatchFailure(
          op, "try run --sparse-reinterpret-map before codegen");
    }
    // ...
    Value sizeHint = op.getSizeHint();
    // ...
    createAllocFields(rewriter, loc, resType, enableBufferInitialization,
                      sizeHint, lvlSizesValues, fields);
    // ...
  }

private:
  bool enableBufferInitialization;
};
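// SparseTensorEmptyConverter: lowers tensor.empty with a sparse encoding
// into freshly allocated fields.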
SparseTensorEmptyConverter(const TypeConverter &typeConverter,
                           MLIRContext *context, bool enableInit)
    : OpConversionPattern(typeConverter, context),
      enableBufferInitialization(enableInit) {}

LogicalResult
matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  if (!resType.hasEncoding())
    return failure();
  // ...
  if (!resType.isIdentity()) {
    return rewriter.notifyMatchFailure(
        op, "try run --sparse-reinterpret-map before codegen");
  }
  // ...
  createAllocFields(rewriter, loc, resType, enableBufferInitialization,
                    sizeHint, lvlSizesValues, fields);
  // ...
}

private:
  bool enableBufferInitialization;
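// SparseTensorDeallocConverter: lowers bufferization.dealloc_tensor,
// optionally deallocating each underlying buffer.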
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  // ...
  SparseTensorDeallocConverter(const TypeConverter &typeConverter,
                               MLIRContext *context, bool createDeallocs)
      : OpConversionPattern(typeConverter, context),
        createDeallocs(createDeallocs) {}

  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (createDeallocs) {
      // ...
      auto desc = getDescriptorFromTensorTuple(
          adaptor.getTensor(),
          cast<RankedTensorType>(op.getTensor().getType()));
      // ...
      rewriter.create<memref::DeallocOp>(loc, input);
      // ...
    }
    // ...
  }

private:
  const bool createDeallocs;
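// SparseTensorLoadConverter: lowers sparse_tensor.load, generating
// insertion finalization code when inserts occurred.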
matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // Prepare the descriptor.
  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                           op.getTensor().getType());
  // Generate optional insertion finalization code.
  if (op.getHasInserts())
    genEndInsert(rewriter, op.getLoc(), desc);
  // ...
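// SparseExpandConverter: lowers sparse_tensor.expand into values/filled/
// added scratch buffers plus a running count.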
matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                           op.getTensor().getType());
  // ...
  Type eltType = srcType.getElementType();
  // ...
  const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
  // ...
  const auto genAlloc = [&](Type t) {
    // ...
  };
  // ...
  Value values = genAlloc(eltType);
  Value filled = genAlloc(boolType);
  Value added = genAlloc(idxType);
  // Reset the values/filled-switch to all-zero/false.
  rewriter.create<linalg::FillOp>(
      loc, ValueRange{constantZero(rewriter, loc, eltType)},
      ValueRange{values});
  rewriter.create<linalg::FillOp>(
      loc, ValueRange{constantZero(rewriter, loc, boolType)},
      ValueRange{filled});
  // ...
  assert(op.getNumResults() == 4);
  rewriter.replaceOp(op, {values, filled, added, zero});
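// SparseCompressConverter: lowers sparse_tensor.compress, sorting the
// added coordinates when the target level is ordered, looping over them
// to insert, and deallocating the scratch buffers afterwards.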
matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
                                              op.getTensor().getType());
  // ...
  Type eltType = dstType.getElementType();
  // ...
  if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
    rewriter.create<SortOp>(/*...*/,
                            rewriter.getIndexAttr(0),
                            SparseTensorSortKind::HybridQuickSort);
  // ...
  Value i = loop.getInductionVar();
  // ...
  SmallVector<Type> flatSpTensorTps = llvm::to_vector(
      llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
  // ...
  params.append(flatLvlCoords.begin(), flatLvlCoords.end());
  params.push_back(crd);
  params.push_back(value);
  SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
                                  /*...*/);
  // ...
  rewriter.create<scf::YieldOp>(loc, insertRet);
  // ...
  rewriter.create<memref::DeallocOp>(loc, values);
  rewriter.create<memref::DeallocOp>(loc, filled);
  rewriter.create<memref::DeallocOp>(loc, added);
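// SparseInsertConverter: lowers tensor.insert into a call to (or inlined
// body of) the generated insertion routine.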
matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  if (!stt.hasEncoding())
    return failure();
  assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
  // ...
  params.append(flatIndices.begin(), flatIndices.end());
  // ...
  SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
                                  /*...*/);
  // ...
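// SparseToPositionsConverter: replaces ToPositionsOp with the positions
// buffer of the descriptor.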
using OpAdaptor = typename ToPositionsOp::Adaptor;
// ...
matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  Level lvl = op.getLevel();
  // ...
  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                           op.getTensor().getType());
  // ...
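// SparseToCoordinatesConverter: replaces ToCoordinatesOp with the
// coordinates buffer (or a view into the AoS buffer).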
class SparseToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpAdaptor = typename ToCoordinatesOp::Adaptor;
  // ...
  matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    Level lvl = op.getLevel();
    // ...
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
    // ...
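// SparseToCoordinatesBufferConverter: replaces ToCoordinatesBufferOp with
// the underlying AoS coordinates buffer.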
class SparseToCoordinatesBufferConverter
    : public OpConversionPattern<ToCoordinatesBufferOp> {
public:
  using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
  // ...
  matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    // ...
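// SparseToValuesConverter: replaces ToValuesOp with the values buffer of
// the descriptor.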
using OpAdaptor = typename ToValuesOp::Adaptor;
// ...
matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                           op.getTensor().getType());
  // ...
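// SparseConvertConverter: lowers ConvertOp between compatible encodings by
// cloning each field, element-wise casting when the field types differ,
// or a plain buffer copy otherwise.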
matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  SparseTensorEncodingAttr encSrc =
      getSparseTensorEncoding(op.getSource().getType());
  // ...
  assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slice.");
  // ...
  if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
      /*...*/)
    return failure();
  // ...
  Type retElemTp = op.getResult().getType().getElementType();
  Type srcElemTp = op.getSource().getType().getElementType();
  // ...
  if (retElemTp == srcElemTp && encDst == encSrc) {
    // ...
  }
  // ...
  auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
                                              op.getSource().getType());
  foreachFieldAndTypeInSparseTensor(
      /*...*/,
      [&rewriter, &fields, srcDesc,
       /*...*/ {
        if (fKind == SparseTensorFieldKind::StorageSpec) {
          fields.push_back(srcDesc.getSpecifier());
        } else {
          // ...
          Value srcMem = srcDesc.getMemRefField(fIdx);
          // ...
          Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
          auto dstMem = rewriter.create<memref::AllocOp>(
              loc, cast<MemRefType>(fTp), sz);
          if (fTp != srcMem.getType()) {
            // Converts the element type.
            scf::buildLoopNest(
                rewriter, loc, constantIndex(rewriter, loc, 0), sz,
                constantIndex(rewriter, loc, 1),
                [srcMem, &dstMem](OpBuilder &builder, Location loc,
                                  ValueRange ivs) {
                  Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
                  Value casted = genCast(builder, loc, v,
                                         dstMem.getType().getElementType());
                  builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
                });
          } else {
            // Simply copies the buffer.
            rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
          }
          fields.push_back(dstMem);
        }
        // ...
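// SparseExtractSliceConverter: lowers tensor.extract_slice to a fresh
// storage specifier that records the slice offsets, sizes, and strides.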
class SparseExtractSliceConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
  // ...
  matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    if (!srcEnc || !dstEnc || !dstEnc.isSlice())
      return failure();
    assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
    // ...
    auto desc = getDescriptorFromTensorTuple(adaptor.getSource(),
                                             op.getSource().getType());
    auto newSpec = rewriter.create<StorageSpecifierInitOp>(
        /*...*/);
    // ...
    for (auto [idx, offset, size, stride] : llvm::enumerate(
             op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
      // ...
      assert(srcEnc.isIdentity());
      // ...
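// SparseNumberOfEntriesConverter: queries the stored-entry count from the
// values memory size.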
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
  // ...
  matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    // ...
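// SparseAssembleOpConverter: materializes sparse tensor storage from the
// assembled level buffers and values, then initializes the storage
// specifier (level sizes and memory sizes).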
matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  foreachFieldAndTypeInSparseTensor(
      stt,
      [&rewriter, &fields, &op, &stt,
       /*...*/ {
        assert(fields.size() == fIdx);
        if (fKind == SparseTensorFieldKind::StorageSpec) {
          fields.push_back(
              SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
        } else {
          // ...
          Value tensor = fKind == SparseTensorFieldKind::ValMemRef
                             ? op.getValues()
                             : op.getLevels()[fIdx];
          // ...
          if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
            // Flattens the buffer to the batch level rank.
            auto reassoc = getReassociationForFlattening(
                mem.getType(), stt.getBatchLvlRank());
            mem = rewriter.create<memref::CastOp>(
                loc, fType,
                rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
          } else {
            mem = rewriter.create<memref::CastOp>(loc, fType, mem);
          }
          fields.push_back(mem);
        }
        // ...
      });
  // ...
  Level trailCOOStart = stt.getAoSCOOStart();
  Level trailCOORank = stt.getLvlRank() - trailCOOStart;
  // ...
  for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
    assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
    // ...
    auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
    desc.setLvlSize(rewriter, loc, lvl, lvlSize);
    // ...
    if (lvl > trailCOOStart)
      // ...
    if (lt.isa<LevelFormat::Dense>()) {
      memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
      posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
      // ...
    }
    if (lt.isa<LevelFormat::Batch>()) {
      // ...
    }
    // ...
    memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
    posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
    // ...
    memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
    // ...
    batched.push_back(posBack);
    // ...
    posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
    // ...
    if (lvl == trailCOOStart) {
      // ...
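// SparseDisassembleOpConverter: copies each field of the descriptor out
// into the user-supplied buffers and returns the corresponding lengths.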
struct SparseDisassembleOpConverter
    : public OpConversionPattern<DisassembleOp> {
  SparseDisassembleOpConverter(const TypeConverter &typeConverter,
                               MLIRContext *context)
      : OpConversionPattern(typeConverter, context) {}

  LogicalResult
  matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    // ...
    if (fKind == SparseTensorFieldKind::StorageSpec)
      return true;
    // ...
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      // ...
      dst = genToMemref(rewriter, loc, op.getOutValues());
      // ...
      retMem.push_back(dst);
      Type valLenTp = op.getValLen().getType();
      // ...
    } else {
      assert(fKind == SparseTensorFieldKind::PosMemRef ||
             fKind == SparseTensorFieldKind::CrdMemRef);
      sz = fKind == SparseTensorFieldKind::PosMemRef
               ? desc.getPosMemSize(rewriter, loc, lvl)
               : desc.getCrdMemSize(rewriter, loc, lvl);
      dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
      retMem.push_back(dst);
      Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
      // ...
    }
    Value flatOut = dst;
    if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
      auto reassoc = getReassociationForFlattening(
          dst.getType(), stt.getBatchLvlRank());
      flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
    }
    // ...
    rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
    // ...
    SmallVector<Value> retValues = llvm::to_vector(
        llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
          return rewriter.create<bufferization::ToTensorOp>(loc, v);
        }));
    retValues.append(retLen.begin(), retLen.end());
    // ...
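// SparseNewConverter: lowers NewOp by opening a file reader, allocating
// COO storage, reading into the position/coordinate/value buffers,
// sorting if needed, and releasing the reader.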
matchAndRewrite(NewOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // ...
  if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
    return failure();
  // ...
  Value dimSizesBuffer;
  Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
                           dimSizesValues, dimSizesBuffer);
  // ...
  Value nse = createFuncCall(rewriter, loc, /*name=*/"...",
                             {indexTp}, {reader}, EmitCInterface::Off)
                  .getResult(0);
  // ...
  Value dim2lvlBuffer;
  Value lvl2dimBuffer;
  genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
                lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
  // ...
  Value sizeHint = nse;
  // ...
  createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint,
                    lvlSizesValues, fields);
  // ...
  const Type elemTp = dstTp.getElementType();
  const Type crdTp = dstTp.getCrdType();
  SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
                                        /*...*/};
  // ...
  createFuncCall(rewriter, loc, readToBuffersFuncName, /*...*/,
                 {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
                 /*...*/);
  // ...
  const Level lvlRank = dstTp.getLvlRank();
  if (dstTp.isOrderedLvl(lvlRank - 1)) {
    // ...
    Value notSorted = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, isSorted, kFalse);
    scf::IfOp ifOp =
        rewriter.create<scf::IfOp>(loc, notSorted, /*else=*/false);
    // ...
    rewriter.create<SortOp>(/*...*/,
                            SparseTensorSortKind::HybridQuickSort);
    // ...
  }
  // ...
  const Type posTp = dstTp.getPosType();
  const Value posNse = genCast(rewriter, loc, nse, posTp);
  rewriter.create<memref::StoreOp>(loc, posNse, posMemref0, c1);
  // ...
  Value coordinatesSize = rewriter.create<arith::MulIOp>(
      /*...*/);
  // ...
  createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                 EmitCInterface::Off);
  // ...
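// SparseHasRuntimeLibraryConverter: rewrites HasRuntimeLibraryOp to a
// constant false, since the codegen path runs without the runtime library.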
struct SparseHasRuntimeLibraryConverter
    : public OpConversionPattern<HasRuntimeLibraryOp> {
  // ...
  LogicalResult
  matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // ...
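// populateSparseTensorCodegenPatterns: sets up the sparse tensor codegen
// rules.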
void mlir::populateSparseTensorCodegenPatterns(
    const TypeConverter &typeConverter, RewritePatternSet &patterns,
    bool createSparseDeallocs, bool enableBufferInitialization) {
  patterns.add<
      SparseAssembleOpConverter, SparseDisassembleOpConverter,
      SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
      SparseCastConverter, SparseExtractSliceConverter,
      SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter,
      SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter,
      SparseSliceGetterOpConverter<ToSliceOffsetOp,
                                   StorageSpecifierKind::DimOffset>,
      SparseSliceGetterOpConverter<ToSliceStrideOp,
                                   StorageSpecifierKind::DimStride>,
      SparseToPositionsConverter, SparseToCoordinatesConverter,
      SparseToCoordinatesBufferConverter, SparseToValuesConverter,
      SparseConvertConverter, SparseNewConverter,
      SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
      typeConverter, patterns.getContext());
  patterns.add<SparseTensorDeallocConverter>(
      typeConverter, patterns.getContext(), createSparseDeallocs);
  patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
      typeConverter, patterns.getContext(), enableBufferInitialization);
}