49 for (
auto operand : operands) {
56 flattened.append(tuple.getOperands().begin(), tuple.getOperands().end());
58 flattened.push_back(operand);
66 return builder.
create<memref::LoadOp>(loc, mem, idx);
73 val =
genCast(builder, loc, val,
74 cast<ShapedType>(mem.
getType()).getElementType());
75 builder.
create<memref::StoreOp>(loc, val, mem, idx);
86 scf::ForOp forOp = builder.
create<scf::ForOp>(loc, lower, upper, one, fields);
87 for (
unsigned i = 0, e = fields.size(); i < e; i++)
88 fields[i] = forOp.getRegionIterArg(i);
102 auto pushBackOp = builder.
create<PushBackOp>(
104 genCast(builder, loc, value, etp), repeat);
108 pushBackOp.getNewSize());
117 for (
Level lvl = startLvl; lvl < lvlRank; lvl++) {
128 linear = builder.
create<arith::MulIOp>(loc, linear, two);
141 linear = builder.
create<arith::MulIOp>(loc, linear, size);
146 std::nullopt, valZero, linear);
151 MemRefType memRefType,
Value sz,
153 Value buffer = builder.
create<memref::AllocOp>(loc, memRefType, sz);
154 Type elemType = memRefType.getElementType();
157 builder.
create<linalg::FillOp>(loc, fillValue, buffer);
167 dimSizesValues.clear();
168 dimSizesValues.reserve(dimRank);
171 dimSizesValues.push_back(ShapedType::isDynamic(sz)
190 Value posHeuristic, crdHeuristic, valHeuristic;
192 valHeuristic = lvlSizesValues[0];
193 for (
Level lvl = 1; lvl < lvlRank; lvl++)
195 builder.
create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
196 }
else if (sizeHint) {
199 crdHeuristic = builder.
create<arith::MulIOp>(
202 posHeuristic = builder.
create<arith::AddIOp>(
204 crdHeuristic = sizeHint;
206 posHeuristic = crdHeuristic =
constantIndex(builder, loc, 16);
208 valHeuristic = sizeHint;
210 posHeuristic = crdHeuristic = valHeuristic =
217 [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
220 assert(fields.size() == fIdx);
228 posHeuristic, enableInit);
232 crdHeuristic, enableInit);
236 valHeuristic, enableInit);
240 fields.push_back(field);
249 for (
Level lvl = 0, lvlRank = stt.
getLvlRank(); lvl < lvlRank; lvl++) {
250 desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
286 assert(lvl < lvlRank &&
"Level is out of bounds");
287 assert(lvlCoords.size() ==
static_cast<size_t>(lvlRank) &&
288 "Level-rank mismatch");
296 const Value pp1 = builder.
create<arith::AddIOp>(loc, parentPos, one);
298 const Value pstart =
genLoad(builder, loc, positionsAtLvl, parentPos);
299 const Value pstop =
genLoad(builder, loc, positionsAtLvl, pp1);
301 const Value crdStrideC =
304 crdStrideC ? builder.
create<arith::DivUIOp>(loc, crdMsz, crdStrideC)
307 loc,
genCast(builder, loc, pstop, indexType), one);
309 Value lt = builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
311 types.push_back(boolType);
312 scf::IfOp ifOp1 = builder.
create<scf::IfOp>(loc, types, lt,
true);
317 crdStrideC ? builder.
create<arith::MulIOp>(loc, plast, crdStrideC)
320 loc, arith::CmpIPredicate::eq,
genCast(builder, loc, crd, indexType),
322 builder.
create<scf::YieldOp>(loc, eq);
325 genStore(builder, loc, msz, positionsAtLvl, parentPos);
333 for (
unsigned i = 0, e = desc.
getNumFields(); i < e; i++)
335 types.push_back(indexType);
338 scf::IfOp ifOp2 = builder.
create<scf::IfOp>(loc, types, p,
true);
350 Value mszp1 = builder.
create<arith::AddIOp>(loc, msz, one);
351 genStore(builder, loc, mszp1, positionsAtLvl, pp1);
355 if ((lvl + 1) < lvlRank)
365 for (
unsigned i = 0, e = desc.
getNumFields(); i < e; i++)
366 desc.
setField(i, ifOp2.getResult(o++));
367 return ifOp2.getResult(o);
375 for (
Level lvl = 0; lvl < lvlRank; lvl++) {
392 scf::ForOp loop =
createFor(builder, loc, hi, inits, one);
393 Value i = loop.getInductionVar();
394 Value oldv = loop.getRegionIterArg(0);
398 loc, arith::CmpIPredicate::eq, newv, posZero);
402 genStore(builder, loc, oldv, posMemRef, i);
403 builder.
create<scf::YieldOp>(loc, oldv);
405 builder.
create<scf::YieldOp>(loc, newv);
407 builder.
create<scf::YieldOp>(loc, ifOp.getResult(0));
420 auto memTp = llvm::cast<MemRefType>(mem.
getType());
424 if (memTp.getRank() > 1)
428 .
create<memref::SubViewOp>(
443 for (
unsigned i = 0; i < batchLvls; i++)
446 for (
int i = batchLvls, e = srcTp.getRank(); i < e; i++)
447 ret.back().push_back(i);
459 class SparseInsertGenerator
479 const Level lvlRank = stt.getLvlRank();
484 llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
485 Value value = args.back();
488 for (
Level lvl = 0; lvl < lvlRank; lvl++) {
489 const auto lt = stt.getLvlType(lvl);
500 parentPos = builder.
create<arith::MulIOp>(loc, parentPos, two);
503 genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
517 Value mult = builder.
create<arith::MulIOp>(loc, size, parentPos);
518 parentPos = builder.
create<arith::AddIOp>(loc, mult, coords[lvl]);
522 if (!stt.isDenseLvl(lvlRank - 1))
524 std::nullopt, value);
530 std::string getMangledFuncName() {
533 constexpr
const char kInsertFuncNamePrefix[] =
"_insert_";
536 llvm::raw_svector_ostream nameOstream(nameBuffer);
537 nameOstream << kInsertFuncNamePrefix;
538 const Level lvlRank = stt.getLvlRank();
539 for (
Level l = 0; l < lvlRank; l++) {
543 lvlType.begin(), lvlType.end(),
544 [](
char c) { return c ==
'(' || c ==
','; },
'_');
545 llvm::erase_if(lvlType, [](
char c) {
return c ==
')' || c ==
' '; });
546 nameOstream << lvlType <<
"_";
551 for (
const auto sz : stt.getDimShape())
552 nameOstream << sz <<
"_";
554 if (!stt.isIdentity())
555 nameOstream << stt.getDimToLvl() <<
"_";
556 nameOstream << stt.getElementType() <<
"_";
557 nameOstream << stt.getCrdWidth() <<
"_" << stt.getPosWidth();
558 return nameOstream.str().str();
570 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
586 matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
595 if (failed(typeConverter->convertTypes(op.
getResultTypes(), finalRetTy)))
601 auto newCall = rewriter.
create<func::CallOp>(loc, op.getCallee(),
602 finalRetTy, flattened);
607 unsigned retOffset = 0;
612 assert(retOffset < newCall.getNumResults());
613 auto retType = ret.getType();
614 if (failed(typeConverter->convertType(retType, sparseFlat)))
615 llvm_unreachable(
"Failed to convert type in sparse tensor codegen");
618 assert(!sparseFlat.empty());
619 if (sparseFlat.size() > 1) {
620 auto flatSize = sparseFlat.size();
622 newCall.result_begin() + retOffset,
623 newCall.result_begin() + retOffset + flatSize));
624 castedRet.push_back(
genTuple(rewriter, loc, retType, fields));
625 retOffset += flatSize;
628 castedRet.push_back(newCall.getResult(retOffset));
645 matchAndRewrite(LvlOp op, OpAdaptor adaptor,
647 std::optional<int64_t> lvl = op.getConstantLvlIndex();
663 matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor,
693 rewriter.
replaceOp(op, adaptor.getInputCoo());
698 template <
typename Op, StorageSpecifierKind kind>
703 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
708 op.getDim().getZExtValue());
720 matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
725 if (!encDst || encDst != encSrc)
727 rewriter.
replaceOp(op, adaptor.getOperands());
736 matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
739 rewriter.
replaceOp(op, adaptor.getSource());
745 class SparseTensorAllocConverter
752 enableBufferInitialization(enableInit) {}
755 matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
758 if (!resType.hasEncoding())
769 auto memrefTp = cast<MemRefType>(field.getType());
770 auto size = rewriter.
create<memref::DimOp>(loc, field, 0);
773 rewriter.
create<memref::CopyOp>(loc, field, copied);
774 fields.push_back(copied);
783 if (!resType.isIdentity()) {
785 op,
"try run --sparse-reinterpret-map before codegen");
793 Value sizeHint = op.getSizeHint();
796 sizeHint, lvlSizesValues, fields);
804 bool enableBufferInitialization;
814 enableBufferInitialization(enableInit) {}
817 matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
820 if (!resType.hasEncoding())
823 if (!resType.isIdentity()) {
825 op,
"try run --sparse-reinterpret-map before codegen");
837 sizeHint, lvlSizesValues, fields);
845 bool enableBufferInitialization;
849 class SparseTensorDeallocConverter
856 createDeallocs(createDeallocs) {}
859 matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
867 if (createDeallocs) {
873 rewriter.
create<memref::DeallocOp>(loc, input);
880 const bool createDeallocs;
888 matchAndRewrite(LoadOp op, OpAdaptor adaptor,
893 if (op.getHasInserts())
906 matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
913 Type eltType = srcType.getElementType();
921 const auto sz = desc.
getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
923 const auto genAlloc = [&](
Type t) {
930 Value values = genAlloc(eltType);
931 Value filled = genAlloc(boolType);
932 Value added = genAlloc(idxType);
939 rewriter.
create<linalg::FillOp>(
942 rewriter.
create<linalg::FillOp>(
947 rewriter.
replaceOp(op, {values, filled, added, zero});
957 matchAndRewrite(CompressOp op, OpAdaptor adaptor,
962 Value values = adaptor.getValues();
963 Value filled = adaptor.getFilled();
964 Value added = adaptor.getAdded();
965 Value count = adaptor.getCount();
967 Type eltType = dstType.getElementType();
971 if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
974 rewriter.
getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
991 Value i = loop.getInductionVar();
997 llvm::map_range(desc.
getFields(), [](
Value v) { return v.getType(); }));
998 params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
999 params.push_back(crd);
1000 params.push_back(value);
1001 SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
1006 rewriter.
create<scf::YieldOp>(loc, insertRet);
1009 Value result =
genTuple(rewriter, loc, dstType, loop->getResults());
1013 rewriter.
create<memref::DeallocOp>(loc, values);
1014 rewriter.
create<memref::DeallocOp>(loc, filled);
1015 rewriter.
create<memref::DeallocOp>(loc, added);
1027 matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
1030 if (!stt.hasEncoding())
1032 assert(stt.isIdentity() &&
"Run reinterpret-map before conversion.");
1038 params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
1039 params.push_back(adaptor.getScalar());
1040 SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
1045 genTuple(rewriter, loc, op.getDest().getType(), ret));
1053 using OpAdaptor =
typename ToPositionsOp::Adaptor;
1056 matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
1062 Level lvl = op.getLevel();
1072 class SparseToCoordinatesConverter
1075 using OpAdaptor =
typename ToCoordinatesOp::Adaptor;
1078 matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
1084 Level lvl = op.getLevel();
1086 auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
1097 class SparseToCoordinatesBufferConverter
1100 using OpAdaptor =
typename ToCoordinatesBufferOp::Adaptor;
1103 matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
1121 using OpAdaptor =
typename ToValuesOp::Adaptor;
1124 matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
1143 matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
1146 SparseTensorEncodingAttr encSrc =
1150 assert(!encDst.isSlice() &&
"Cannot convert to a sparse tensor slices.");
1154 if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
1160 Type srcElemTp = op.getSource().getType().getElementType();
1162 if (retElemTp == srcElemTp && encDst == encSrc) {
1163 rewriter.
replaceOp(op, adaptor.getSource());
1180 [&rewriter, &fields, srcDesc,
1184 if (fKind == SparseTensorFieldKind::StorageSpec) {
1185 fields.push_back(srcDesc.getSpecifier());
1188 Value srcMem = srcDesc.getMemRefField(fIdx);
1192 Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
1193 auto dstMem = rewriter.create<memref::AllocOp>(
1194 loc, cast<MemRefType>(fTp), sz);
1195 if (fTp != srcMem.getType()) {
1198 rewriter, loc, constantIndex(rewriter, loc, 0), sz,
1199 constantIndex(rewriter, loc, 1),
1200 [srcMem, &dstMem](OpBuilder &builder, Location loc,
1202 Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
1203 Value casted = genCast(builder, loc, v,
1204 dstMem.getType().getElementType());
1205 builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
1211 rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
1213 fields.push_back(dstMem);
1224 class SparseExtractSliceConverter
1229 matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
1236 if (!srcEnc || !dstEnc || !dstEnc.isSlice())
1238 assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
1243 auto newSpec = rewriter.
create<StorageSpecifierInitOp>(
1249 op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
1264 assert(srcEnc.isIdentity());
1281 class SparseNumberOfEntriesConverter
1286 matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
1300 matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
1309 [&rewriter, &fields, &op, &stt,
1312 assert(fields.size() == fIdx);
1313 if (fKind == SparseTensorFieldKind::StorageSpec) {
1315 SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
1318 Value tensor = fKind == SparseTensorFieldKind::ValMemRef
1320 : op.getLevels()[fIdx];
1323 if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
1326 mem.getType(), stt.getBatchLvlRank());
1327 mem = rewriter.
create<memref::CastOp>(
1329 rewriter.
create<memref::CollapseShapeOp>(loc, mem, reassoc));
1331 mem = rewriter.
create<memref::CastOp>(loc, fType, mem);
1333 fields.push_back(mem);
1345 Level trailCOOStart = stt.getAoSCOOStart();
1346 Level trailCOORank = stt.getLvlRank() - trailCOOStart;
1348 for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
1349 assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
1352 auto lvlSize =
constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
1353 desc.
setLvlSize(rewriter, loc, lvl, lvlSize);
1356 if (lvl > trailCOOStart)
1362 if (lt.
isa<LevelFormat::Dense>()) {
1363 memSize = rewriter.
create<arith::MulIOp>(loc, lvlSize, memSize);
1364 posBack = rewriter.
create<arith::SubIOp>(loc, memSize, c1);
1367 if (lt.
isa<LevelFormat::Batch>()) {
1377 memSize = rewriter.
create<arith::MulIOp>(loc, memSize, c2);
1378 posBack = rewriter.
create<arith::SubIOp>(loc, memSize, c1);
1382 memSize = rewriter.
create<arith::AddIOp>(loc, memSize, c1);
1390 batched.push_back(posBack);
1392 posBack = rewriter.
create<arith::SubIOp>(loc, posBack, c1);
1396 if (lvl == trailCOOStart) {
1411 struct SparseDisassembleOpConverter
1419 matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
1429 if (fKind == SparseTensorFieldKind::StorageSpec)
1434 if (fKind == SparseTensorFieldKind::ValMemRef) {
1437 dst =
genToMemref(rewriter, loc, op.getOutValues());
1439 retMem.push_back(dst);
1440 Type valLenTp = op.getValLen().getType();
1443 assert(fKind == SparseTensorFieldKind::PosMemRef ||
1444 fKind == SparseTensorFieldKind::CrdMemRef);
1446 sz = fKind == SparseTensorFieldKind::PosMemRef
1450 dst =
genToMemref(rewriter, loc, op.getOutLevels()[fid]);
1451 retMem.push_back(dst);
1453 Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
1456 Value flatOut = dst;
1457 if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
1460 flatOut = rewriter.
create<memref::CollapseShapeOp>(loc, dst, reassoc);
1464 rewriter.
create<memref::CopyOp>(loc, srcMem, dstMem);
1470 llvm::map_range(retMem, [&rewriter, loc](
Value v) ->
Value {
1471 return rewriter.
create<bufferization::ToTensorOp>(loc, v);
1474 retValues.append(retLen.begin(), retLen.end());
1483 matchAndRewrite(
NewOp op, OpAdaptor adaptor,
1489 if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
1504 Value dimSizesBuffer;
1505 Value reader =
genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
1506 dimSizesValues, dimSizesBuffer);
1511 {indexTp}, {reader}, EmitCInterface::Off)
1516 Value dim2lvlBuffer;
1517 Value lvl2dimBuffer;
1518 genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
1519 lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
1522 Value sizeHint = nse;
1525 lvlSizesValues, fields);
1532 const Type elemTp = dstTp.getElementType();
1533 const Type crdTp = dstTp.getCrdType();
1534 SmallString<32> readToBuffersFuncName{
"getSparseTensorReaderReadToBuffers",
1539 {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
1545 const Level lvlRank = dstTp.getLvlRank();
1546 if (dstTp.isOrderedLvl(lvlRank - 1)) {
1549 loc, arith::CmpIPredicate::eq, isSorted, kFalse);
1551 rewriter.
create<scf::IfOp>(loc, notSorted,
false);
1556 SparseTensorSortKind::HybridQuickSort);
1563 const Type posTp = dstTp.getPosType();
1564 const Value posNse =
genCast(rewriter, loc, nse, posTp);
1565 rewriter.
create<memref::StoreOp>(loc, posNse, posMemref0, c1);
1568 Value coordinatesSize = rewriter.
create<arith::MulIOp>(
1576 createFuncCall(rewriter, loc,
"delSparseTensorReader", {}, {reader},
1577 EmitCInterface::Off);
1585 struct SparseHasRuntimeLibraryConverter
1589 matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
1608 bool createSparseDeallocs,
bool enableBufferInitialization) {
1610 SparseAssembleOpConverter, SparseDisassembleOpConverter,
1611 SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
1612 SparseCastConverter, SparseExtractSliceConverter,
1613 SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter,
1614 SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter,
1615 SparseSliceGetterOpConverter<ToSliceOffsetOp,
1616 StorageSpecifierKind::DimOffset>,
1617 SparseSliceGetterOpConverter<ToSliceStrideOp,
1618 StorageSpecifierKind::DimStride>,
1619 SparseToPositionsConverter, SparseToCoordinatesConverter,
1620 SparseToCoordinatesBufferConverter, SparseToValuesConverter,
1621 SparseConvertConverter, SparseNewConverter,
1622 SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>(
1624 patterns.
add<SparseTensorDeallocConverter>(
1625 typeConverter, patterns.
getContext(), createSparseDeallocs);
1626 patterns.
add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
1627 typeConverter, patterns.
getContext(), enableBufferInitialization);
static void flattenOperands(ValueRange operands, SmallVectorImpl< Value > &flattened)
Flattens a list of operands that may contain sparse tensors.
static void createAllocFields(OpBuilder &builder, Location loc, SparseTensorType stt, bool enableInit, Value sizeHint, SmallVectorImpl< Value > &lvlSizesValues, SmallVectorImpl< Value > &fields)
Creates allocation for each field in sparse tensor type.
static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper, MutableArrayRef< Value > fields, Value lower=Value())
Creates a straightforward counting for-loop.
static SmallVector< ReassociationIndices > getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls)
Creates the reassociation array.
static void genEndInsert(OpBuilder &builder, Location loc, SparseTensorDescriptor desc)
Generates insertion finalization code.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static void allocSchemeForRank(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, Level startLvl)
Generates code that allocates a sparse storage scheme for given rank.
static Value genCompressed(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, ValueRange lvlCoords, Value, Value parentPos, Level lvl)
Helper method that generates block specific to compressed case:
static Value createAllocation(OpBuilder &builder, Location loc, MemRefType memRefType, Value sz, bool enableInit)
Creates allocation operation.
static void createPushback(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, SparseTensorFieldKind kind, std::optional< Level > lvl, Value value, Value repeat=Value())
Creates a push back operation.
static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, Value sz)
Generates a subview into the sizes.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
static void createDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt, ValueRange dynSizes, SmallVectorImpl< Value > &dimSizesValues)
Creates the dim sizes array, filling in from dynamic sizes.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerType getIntegerType(unsigned width)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A helper class to simplify lowering operations with/without function calls.
Using SmallVector for mutable descriptor allows users to reuse it as a tmp buffers to append value fo...
void setMemRefField(SparseTensorFieldKind kind, std::optional< Level > lvl, Value v)
Adds additional setters for mutable descriptor, update the value for required field.
void setSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl, Value v)
void setSpecifier(Value newSpec)
void setPosMemSize(OpBuilder &builder, Location loc, Level lvl, Value v)
void setValMemSize(OpBuilder &builder, Location loc, Value v)
void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value v)
void setField(FieldIndex fidx, Value v)
void setCrdMemSize(OpBuilder &builder, Location loc, Level lvl, Value v)
Value getSpecifier() const
Getters: get the value for required field.
RankedTensorType getRankedTensorType() const
ValueArrayRef getFields() const
std::pair< FieldIndex, unsigned > getCrdMemRefIndexAndStride(Level lvl) const
Value getValMemSize(OpBuilder &builder, Location loc) const
Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl) const
StorageLayout getLayout() const
ValueRange getMemRefFields() const
Value getAOSMemRef() const
Type getMemRefElementType(SparseTensorFieldKind kind, std::optional< Level > lvl) const
unsigned getNumFields() const
Value getCrdMemSize(OpBuilder &builder, Location loc, Level lvl) const
Value getMemRefField(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Value getPosMemSize(OpBuilder &builder, Location loc, Level lvl) const
Value getValMemRef() const
Value getField(FieldIndex fidx) const
Value getLvlSize(OpBuilder &builder, Location loc, Level lvl) const
Value getPosMemRef(Level lvl) const
Uses ValueRange for immutable descriptors.
static Value getInitValue(OpBuilder &builder, Location loc, SparseTensorType stt)
A wrapper around RankedTensorType, which has three goals:
Type getElementType() const
ArrayRef< Size > getDimShape() const
Returns the dimension-shape.
bool isAllOrdered() const
Returns true for tensors where every level is ordered.
bool isCOOType(Level startLvl=0, bool isUnique=true) const
Returns true iff this sparse tensor type has a trailing COO region starting at the given level.
Dimension getDimRank() const
Returns the dimension-rank.
bool isAllDense() const
Returns true for tensors where every level is dense.
bool hasSameDimToLvl(const SparseTensorType &other) const
Returns true iff the two types have the same mapping.
bool isCompressedLvl(Level l) const
Level getLvlRank() const
Returns the level-rank.
bool isDenseLvl(Level l) const
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
LevelType getLvlType(Level l) const
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
bool isUniqueLvl(Level l) const
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with t...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool isWithCrdLT(LevelType lt)
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
bool isWithPosLT(LevelType lt)
std::string toMLIRString(LevelType lt)
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
bool isSingletonLT(LevelType lt)
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
bool isCompressedLT(LevelType lt)
uint64_t Level
The type of level identifiers and level-ranks.
TypedValue< BaseMemRefType > genToMemref(OpBuilder &builder, Location loc, Value tensor)
Value genTuple(OpBuilder &builder, Location loc, Type tp, ValueRange values)
Packs the given values as a "tuple" value.
bool isLooseCompressedLT(LevelType lt)
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, ValueRange s)
Generates a pointer/index load from the sparse storage scheme.
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
StringRef overheadTypeFunctionSuffix(OverheadType ot)
Convert OverheadType to its function-name suffix.
Operation * getTop(Operation *op)
Scans to top of generated loop.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Value genMapBuffers(OpBuilder &builder, Location loc, SparseTensorType stt, ArrayRef< Value > dimSizesValues, Value dimSizesBuffer, SmallVectorImpl< Value > &lvlSizesValues, Value &dim2lvlBuffer, Value &lvl2dimBuffer)
Generates code to set up the buffer parameters for a map.
Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt, Value tensor, SmallVectorImpl< Value > &dimSizesValues, Value &dimSizesBuffer)
Generates code that opens a reader and sets the dimension sizes.
Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem, Type dstTp)
Add conversion from scalar to given type (possibly a 0-rank tensor).
SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor)
bool isDenseLT(LevelType lt)
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SparseTensorFieldKind
===-------------------------------------------------------------------—===// The sparse tensor storag...
func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Creates a CallOp to the function reference returned by getFunc() in the builder's module.
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy)
Add type casting between arith and index types when needed.
StringRef primaryTypeFunctionSuffix(PrimaryType pt)
Convert PrimaryType to its function-name suffix.
MutSparseTensorDescriptor getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl< Value > &fields)
Value genValMemSize(OpBuilder &builder, Location loc, Value tensor)
Generates code to retrieve the values size for the sparse tensor.
StorageSpecifierKind toSpecifierKind(SparseTensorFieldKind kind)
bool isNOutOfMLT(LevelType lt)
UnrealizedConversionCastOp getTuple(Value tensor)
Returns the "tuple" value of the adapted tensor.
Include the generated interface declarations.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns, bool createSparseDeallocs, bool enableBufferInitialization)
Sets up sparse tensor codegen rules.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
constexpr bool isa() const
Check if the LevelType is in the LevelFormat.