// flattenOperands: flattens a list of operands that may contain sparse
// tensors; a sparse tensor operand arrives as a "tuple" of storage fields,
// which are appended directly, while any other operand is kept as is.
for (auto operand : operands) {
  flattened.append(tuple.getOperands().begin(), tuple.getOperands().end());
  flattened.push_back(operand);
// genLoad: generates a load with proper index typing.
return builder.create<memref::LoadOp>(loc, mem, idx);

// genStore: generates a store with proper index typing and a properly
// casted value.
val = genCast(builder, loc, val,
              cast<ShapedType>(mem.getType()).getElementType());
builder.create<memref::StoreOp>(loc, val, mem, idx);
// createFor: creates a straightforward counting for-loop, threading `fields`
// through the loop as iter_args.
scf::ForOp forOp = builder.create<scf::ForOp>(loc, lower, upper, one, fields);
for (unsigned i = 0, e = fields.size(); i < e; i++)
  fields[i] = forOp.getRegionIterArg(i);
// createPushback: creates a push back operation on the selected field and
// records the new memory size in the specifier.
auto pushBackOp = builder.create<PushBackOp>(
    /*...*/ genCast(builder, loc, value, etp), repeat);
/*...*/ pushBackOp.getNewSize());
// allocSchemeForRank: generates code that allocates a sparse storage scheme
// for the given rank; the running product `linear` becomes the repeat count
// for the initial push back.
for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
  linear = builder.create<arith::MulIOp>(loc, linear, two);
  linear = builder.create<arith::MulIOp>(loc, linear, size);
/*...*/ std::nullopt, valZero, linear);
// createAllocation: creates the buffer allocation, optionally initialized
// through a linalg.fill.
static Value createAllocation(OpBuilder &builder, Location loc,
                              MemRefType memRefType, Value sz,
                              bool enableInit) {
  Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
  Type elemType = memRefType.getElementType();
  builder.create<linalg::FillOp>(loc, fillValue, buffer);
// createDimSizes: creates the dim sizes array, filling in from dynamic sizes.
dimSizesValues.clear();
dimSizesValues.reserve(dimRank);
dimSizesValues.push_back(ShapedType::isDynamic(sz)
                             ? dynSizes[i++]
                             : constantIndex(builder, loc, sz));
// createAllocFields: creates the allocation for each field of the sparse
// tensor storage; heuristic buffer sizes are derived from the level sizes
// (all-dense case), from an optional size hint, or default to a small
// constant.
Value posHeuristic, crdHeuristic, valHeuristic;
valHeuristic = lvlSizesValues[0];
for (Level lvl = 1; lvl < lvlRank; lvl++)
  valHeuristic =
      builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
} else if (sizeHint) {
  crdHeuristic = builder.create<arith::MulIOp>(/*...*/);
  posHeuristic = builder.create<arith::AddIOp>(/*...*/);
  crdHeuristic = sizeHint;
  posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
  valHeuristic = sizeHint;
  posHeuristic = crdHeuristic = valHeuristic =
      constantIndex(builder, loc, 16);
// For each field, allocate a buffer sized by the matching heuristic.
[&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
 /*...*/
  assert(fields.size() == fIdx);
  /*...*/ posHeuristic, enableInit);
  /*...*/ crdHeuristic, enableInit);
  /*...*/ valHeuristic, enableInit);
  fields.push_back(field);
// Initialize the level sizes in the storage specifier.
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
  desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
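// Sketch, for orientation (an assumption about the resulting layout, not code
// from this file): for a 2-D CSR tensor of f64 the field list built here is
// roughly
//   positions[1]   : memref<?xindex>  (positions of the compressed level)
//   coordinates[1] : memref<?xindex>  (coordinates of the compressed level)
//   values         : memref<?xf64>    (stored values)
//   specifier      : !sparse_tensor.storage_specifier<...>  (size bookkeeping)
// with the overhead element types narrowed according to posWidth/crdWidth.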
// genCompressed: helper that generates the block specific to the compressed
// case at level `lvl` (see the summary sketch after this function).
assert(lvl < lvlRank && "Level is out of bounds");
assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
       "Level-rank mismatch");
const Value pp1 = builder.create<arith::AddIOp>(loc, parentPos, one);
const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
const Value crdStrideC = /*...*/;
/*...*/ crdStrideC ? builder.create<arith::DivUIOp>(loc, crdMsz, crdStrideC)
                   /*...*/
/*...*/ loc, genCast(builder, loc, pstop, indexType), one);
Value lt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                         /*...*/);
types.push_back(boolType);
scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else=*/true);
/*...*/ crdStrideC ? builder.create<arith::MulIOp>(loc, plast, crdStrideC)
                   /*...*/
/*...*/ loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
        /*...*/);
builder.create<scf::YieldOp>(loc, eq);
genStore(builder, loc, msz, positionsAtLvl, parentPos);
for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
  types.push_back(indexType);
scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else=*/true);
Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
genStore(builder, loc, mszp1, positionsAtLvl, pp1);
if ((lvl + 1) < lvlRank)
  /*...*/
for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
  desc.setField(i, ifOp2.getResult(o++));
return ifOp2.getResult(o);
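// Summary sketch (paraphrase, not the file's own comment) of the code that
// genCompressed emits for a compressed level:
//   pstart = positions[lvl][parentPos]
//   pstop  = positions[lvl][parentPos + 1]
//   if (pstart < pstop && coordinates[lvl][pstop - 1] == lvlCoords[lvl])
//     reuse the existing position
//   else
//     append the coordinate, bump positions[lvl][parentPos + 1], and
//     prepare level lvl + 1
// The updated fields are threaded out of the scf.if results, and the returned
// value is the position at which the next level continues.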
// genEndInsert: generates insertion finalization code; position entries that
// were never written (still zero) are overwritten with the previous position
// value, which is carried along as the loop's iter_arg.
for (Level lvl = 0; lvl < lvlRank; lvl++) {
  scf::ForOp loop = createFor(builder, loc, hi, inits, one);
  Value i = loop.getInductionVar();
  Value oldv = loop.getRegionIterArg(0);
  /*...*/ loc, arith::CmpIPredicate::eq, newv, posZero);
  genStore(builder, loc, oldv, posMemRef, i);
  builder.create<scf::YieldOp>(loc, oldv);
  builder.create<scf::YieldOp>(loc, newv);
  builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
// genSliceToSize: generates a subview (of size `sz`) into the given buffer.
auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
return builder
    .create<memref::SubViewOp>(
// getReassociationForFlattening: creates the reassociation array that keeps
// the leading batch dimensions and collapses all remaining dimensions into a
// single trailing group.
for (unsigned i = 0; i < batchLvls; i++)
  /*...*/
for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
  ret.back().push_back(i);
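// Example (illustrative): for a rank-3 source with batchLvls == 1, the
// reassociation built above is {{0}, {1, 2}}: the batch dimension stays and
// the two remaining dimensions collapse into one.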
// SparseInsertGenerator: generates the code that inserts one element into the
// sparse storage scheme, walking the levels from first to last.
class SparseInsertGenerator
  const Level lvlRank = stt.getLvlRank();
  SmallVector<Value> coords =
      llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
  Value value = args.back();
  for (Level lvl = 0; lvl < lvlRank; lvl++) {
    const auto lt = stt.getLvlType(lvl);
    parentPos = builder.create<arith::MulIOp>(loc, parentPos, two);
    parentPos =
        genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
    Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
    parentPos = builder.create<arith::AddIOp>(loc, mult, coords[lvl]);
  }
  // Append/insert the actual value.
  if (!stt.isDenseLvl(lvlRank - 1))
    /*...*/ std::nullopt, value);
// getMangledFuncName: constructs the name of the insertion helper function,
// encoding the level types, the dimension sizes, the dim-to-lvl map (when not
// the identity), the element type, and the coordinate/position bit widths.
std::string getMangledFuncName() {
  constexpr const char kInsertFuncNamePrefix[] = "_insert_";
  llvm::raw_svector_ostream nameOstream(nameBuffer);
  nameOstream << kInsertFuncNamePrefix;
  const Level lvlRank = stt.getLvlRank();
  for (Level l = 0; l < lvlRank; l++) {
    std::string lvlType = toMLIRString(stt.getLvlType(l));
    // Replace/remove punctuation in the level-type string.
    std::replace_if(
        lvlType.begin(), lvlType.end(),
        [](char c) { return c == '(' || c == ','; }, '_');
    llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
    nameOstream << lvlType << "_";
  }
  for (const auto sz : stt.getDimShape())
    nameOstream << sz << "_";
  if (!stt.isIdentity())
    nameOstream << stt.getDimToLvl() << "_";
  nameOstream << stt.getElementType() << "_";
  nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
  return nameOstream.str().str();
}
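// Illustrative example (assumes toMLIRString prints "dense"/"compressed"):
// for a static 10x20 CSR tensor of f64 with 32-bit positions and coordinates,
// the mangled name comes out roughly as
//   _insert_dense_compressed_10_20_f64_32_32
// i.e. prefix, one token per level type, the dimension sizes, the element
// type, and finally the coordinate and position widths.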
// Sparse codegen rule for func::ReturnOp: rewrites the return to use the
// flattened operand list.
matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
// Sparse codegen rule for func::CallOp: calls the callee with flattened
// operands and re-packs any multi-field (sparse) results into "tuple" values.
matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
                                             finalRetTy, flattened);
unsigned retOffset = 0;
assert(retOffset < newCall.getNumResults());
auto retType = ret.getType();
if (failed(typeConverter->convertType(retType, sparseFlat)))
  llvm_unreachable("Failed to convert type in sparse tensor codegen");
assert(!sparseFlat.empty());
if (sparseFlat.size() > 1) {
  auto flatSize = sparseFlat.size();
  // Gather the flattened results that make up this sparse tensor.
  /*...*/ newCall.result_begin() + retOffset,
          newCall.result_begin() + retOffset + flatSize));
  castedRet.push_back(genTuple(rewriter, loc, retType, fields));
  retOffset += flatSize;
} else {
  castedRet.push_back(newCall.getResult(retOffset));
}
// Sparse codegen rule for level accesses (LvlOp).
matchAndRewrite(LvlOp op, OpAdaptor adaptor,
std::optional<int64_t> lvl = op.getConstantLvlIndex();

// Sparse codegen rule for the ReorderCOOOp operator.
matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor,
rewriter.replaceOp(op, adaptor.getInputCoo());
// Sparse codegen rule for the slice offset/stride getters, parameterized by
// the storage specifier field to query.
template <typename Op, StorageSpecifierKind kind>
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
/*...*/ op.getDim().getZExtValue());
// Sparse codegen rule for trivial tensor casts: only identically annotated
// source/destination pairs are rewritten.
matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
if (!encDst || encDst != encSrc)
  return failure();
rewriter.replaceOp(op, adaptor.getOperands());
// Sparse codegen rule for ReinterpretMapOp: simply forwards the source.
matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
rewriter.replaceOp(op, adaptor.getSource());
// Sparse codegen rule for bufferization::AllocTensorOp: either duplicates the
// fields of the copied tensor or allocates fresh fields.
class SparseTensorAllocConverter
      enableBufferInitialization(enableInit) {}
matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
if (!resType.hasEncoding())
  return failure();
// Copy-initialization: duplicate each memref field of the copied tensor.
auto memrefTp = cast<MemRefType>(field.getType());
auto size = rewriter.create<memref::DimOp>(loc, field, 0);
rewriter.create<memref::CopyOp>(loc, field, copied);
fields.push_back(copied);
if (!resType.isIdentity()) {
  return rewriter.notifyMatchFailure(
      op, "try run --sparse-reinterpret-map before codegen");
}
// Construct the allocation for each field.
Value sizeHint = op.getSizeHint();
/*...*/ sizeHint, lvlSizesValues, fields);
bool enableBufferInitialization;
// Sparse codegen rule for tensor::EmptyOp (mirrors the AllocTensorOp rule).
      enableBufferInitialization(enableInit) {}
matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
if (!resType.hasEncoding())
  return failure();
if (!resType.isIdentity()) {
  return rewriter.notifyMatchFailure(
      op, "try run --sparse-reinterpret-map before codegen");
}
/*...*/ sizeHint, lvlSizesValues, fields);
bool enableBufferInitialization;
// Sparse codegen rule for the dealloc operator: deallocates the underlying
// memref fields, but only when createDeallocs is enabled.
class SparseTensorDeallocConverter
      createDeallocs(createDeallocs) {}
matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
if (createDeallocs) {
  rewriter.create<memref::DeallocOp>(loc, input);
}
const bool createDeallocs;
// Sparse codegen rule for the sparse_tensor.load operator: finalizes pending
// insertions when the op carries the hasInserts flag.
matchAndRewrite(LoadOp op, OpAdaptor adaptor,
if (op.getHasInserts())
// Sparse codegen rule for the expand operator: allocates and zero-initializes
// the access-pattern expansion buffers (values, filled switch, added indices)
// together with an initial count of zero.
matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
Type eltType = srcType.getElementType();
const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
const auto genAlloc = [&](Type t) {
Value values = genAlloc(eltType);
Value filled = genAlloc(boolType);
Value added = genAlloc(idxType);
rewriter.create<linalg::FillOp>(
    /*...*/
rewriter.create<linalg::FillOp>(
    /*...*/
rewriter.replaceOp(op, {values, filled, added, zero});
// Sparse codegen rule for the compress operator: sorts the added coordinates
// when the trailing level is ordered, loops over them to insert each
// (coordinate, value) pair through the shared insertion generator, and then
// releases the expansion buffers.
matchAndRewrite(CompressOp op, OpAdaptor adaptor,
Value values = adaptor.getValues();
Value filled = adaptor.getFilled();
Value added = adaptor.getAdded();
Value count = adaptor.getCount();
Type eltType = dstType.getElementType();
if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
  /*...*/ rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
Value i = loop.getInductionVar();
/*...*/ llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
params.push_back(crd);
params.push_back(value);
SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
                                /*...*/);
rewriter.create<scf::YieldOp>(loc, insertRet);
Value result = genTuple(rewriter, loc, dstType, loop->getResults());
// Deallocate the expansion buffers.
rewriter.create<memref::DeallocOp>(loc, values);
rewriter.create<memref::DeallocOp>(loc, filled);
rewriter.create<memref::DeallocOp>(loc, added);
// Sparse codegen rule for the tensor::InsertOp operator: delegates to the
// shared SparseInsertGenerator.
matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
if (!stt.hasEncoding())
  return failure();
assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
params.push_back(adaptor.getScalar());
SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
                                /*...*/);
/*...*/ genTuple(rewriter, loc, op.getDest().getType(), ret));
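// Note (inference from the two uses above, not a statement from this file):
// the insertion generator emits its body either inline or as a call to a
// generated function named by getMangledFuncName(), taking the flattened
// storage fields, the level coordinates, and the value as arguments and
// yielding the updated fields.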
// Sparse codegen rule for position accesses (ToPositionsOp).
using OpAdaptor = typename ToPositionsOp::Adaptor;
matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,

// Sparse codegen rule for coordinate accesses (ToCoordinatesOp).
class SparseToCoordinatesConverter
using OpAdaptor = typename ToCoordinatesOp::Adaptor;
matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
/*...*/ op, desc.getCrdMemRefOrView(rewriter, op.getLoc(), op.getLevel()));

// Sparse codegen rule for accesses to the linear coordinates buffer.
class SparseToCoordinatesBufferConverter
using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,

// Sparse codegen rule for value accesses (ToValuesOp).
using OpAdaptor = typename ToValuesOp::Adaptor;
matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
// Sparse codegen rule for the convert operator: a conversion between
// identical encodings and element types is a no-op; otherwise every storage
// field is copied, with element-wise casts when the field types differ.
matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
SparseTensorEncodingAttr encSrc =
    /*...*/;
assert(!encDst.isSlice() && "Cannot convert to sparse tensor slices.");
if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
    /*...*/
Type srcElemTp = op.getSource().getType().getElementType();
if (retElemTp == srcElemTp && encDst == encSrc) {
  rewriter.replaceOp(op, adaptor.getSource());
[&rewriter, &fields, srcDesc,
 /*...*/
if (fKind == SparseTensorFieldKind::StorageSpec) {
  fields.push_back(srcDesc.getSpecifier());
Value srcMem = srcDesc.getMemRefField(fIdx);
Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
auto dstMem = rewriter.create<memref::AllocOp>(
    loc, cast<MemRefType>(fTp), sz);
if (fTp != srcMem.getType()) {
  // Element types differ: copy element by element, casting each value.
  /*...*/ rewriter, loc, constantIndex(rewriter, loc, 0), sz,
          constantIndex(rewriter, loc, 1),
          [srcMem, &dstMem](OpBuilder &builder, Location loc,
                            /*...*/
    Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
    Value casted = genCast(builder, loc, v,
                           dstMem.getType().getElementType());
    builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
} else {
  // Same type: use a direct copy.
  rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
}
fields.push_back(dstMem);
// Sparse codegen rule for tensor::ExtractSliceOp: reuses the source storage
// and only materializes a new storage specifier that carries the slice
// metadata (offsets, sizes, strides).
class SparseExtractSliceConverter
matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
if (!srcEnc || !dstEnc || !dstEnc.isSlice())
  return failure();
assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
auto newSpec = rewriter.create<StorageSpecifierInitOp>(
    /*...*/
/*...*/ op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
assert(srcEnc.isIdentity());
// Sparse codegen rule for the number_of_entries operator: reads the values
// memory size from the storage specifier.
class SparseNumberOfEntriesConverter
matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
// Sparse codegen rule for the sparse_tensor.assemble operator: wraps the
// user-provided level/value buffers into the internal storage layout and
// computes the memory sizes recorded in the storage specifier.
matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
[&rewriter, &fields, &op, &stt,
 /*...*/
  assert(fields.size() == fIdx);
  if (fKind == SparseTensorFieldKind::StorageSpec) {
    /*...*/ SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
  }
  Value tensor = fKind == SparseTensorFieldKind::ValMemRef
                     ? /*...*/
                     : op.getLevels()[fIdx];
  if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
    // Flatten the (batched) buffer to the expected rank before casting.
    /*...*/ mem.getType(), stt.getBatchLvlRank());
    mem = rewriter.create<memref::CastOp>(
        /*...*/
        rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
  } else {
    mem = rewriter.create<memref::CastOp>(loc, fType, mem);
  }
  fields.push_back(mem);

Level trailCOOStart = stt.getAoSCOOStart();
Level trailCOORank = stt.getLvlRank() - trailCOOStart;
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
  assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
  auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
  desc.setLvlSize(rewriter, loc, lvl, lvlSize);
  if (lvl > trailCOOStart)
    /*...*/
  if (lt.isa<LevelFormat::Dense>()) {
    memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
    posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
  }
  if (lt.isa<LevelFormat::Batch>()) {
    /*...*/
  memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
  posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
  memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
  batched.push_back(posBack);
  posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
  if (lvl == trailCOOStart) {
// Sparse codegen rule for the sparse_tensor.disassemble operator: copies each
// storage field into the user-provided output buffers and returns the actual
// lengths alongside.
struct SparseDisassembleOpConverter
matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
if (fKind == SparseTensorFieldKind::StorageSpec)
  /*...*/
if (fKind == SparseTensorFieldKind::ValMemRef) {
  dst = genToMemref(rewriter, loc, op.getOutValues());
  retMem.push_back(dst);
  Type valLenTp = op.getValLen().getType();
} else {
  assert(fKind == SparseTensorFieldKind::PosMemRef ||
         fKind == SparseTensorFieldKind::CrdMemRef);
  sz = fKind == SparseTensorFieldKind::PosMemRef
           ? /*...*/
           : /*...*/;
  dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
  retMem.push_back(dst);
  Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
}
Value flatOut = dst;
if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
  flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
}
rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
/*...*/ llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
  return rewriter.create<bufferization::ToTensorOp>(loc, v);
});
retValues.append(retLen.begin(), retLen.end());
// Sparse codegen rule for the sparse_tensor.new operator (only for
// destinations with a trailing AoS COO region starting at level 0): opens a
// file reader, reads all entries into the COO buffers, sorts them if the
// destination requires order, and finalizes the storage specifier.
matchAndRewrite(NewOp op, OpAdaptor adaptor,
if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0)
  return failure();
// Create the reader and query the dimension sizes.
Value dimSizesBuffer;
Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
                         dimSizesValues, dimSizesBuffer);
// Query the number of stored entries from the reader.
/*...*/ {indexTp}, {reader}, EmitCInterface::Off)
Value dim2lvlBuffer;
Value lvl2dimBuffer;
genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
              lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
// Use the number of stored entries as the allocation size hint.
Value sizeHint = nse;
/*...*/ lvlSizesValues, fields);
// Read the COO coordinates and values directly into the allocated buffers.
const Type elemTp = dstTp.getElementType();
const Type crdTp = dstTp.getCrdType();
SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
                                      /*...*/};
/*...*/ {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
// Sort the entries if the trailing level must be ordered and the reader did
// not deliver them sorted.
const Level lvlRank = dstTp.getLvlRank();
if (dstTp.isOrderedLvl(lvlRank - 1)) {
  /*...*/ loc, arith::CmpIPredicate::eq, isSorted, kFalse);
  rewriter.create<scf::IfOp>(loc, notSorted, /*else=*/false);
  /*...*/ SparseTensorSortKind::HybridQuickSort);
}
// Record the number of stored entries in the position array.
const Type posTp = dstTp.getPosType();
const Value posNse = genCast(rewriter, loc, nse, posTp);
rewriter.create<memref::StoreOp>(loc, posNse, posMemref0, c1);
Value coordinatesSize = rewriter.create<arith::MulIOp>(
    /*...*/
// Release the reader.
createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
               EmitCInterface::Off);
// Sparse codegen rule for the has_runtime_library operator.
struct SparseHasRuntimeLibraryConverter
matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
// populateSparseTensorCodegenPatterns: sets up the sparse tensor codegen rules.
void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                         RewritePatternSet &patterns,
                                         bool createSparseDeallocs,
                                         bool enableBufferInitialization) {
  patterns.add<SparseAssembleOpConverter, SparseDisassembleOpConverter,
               SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
               SparseCastConverter, SparseExtractSliceConverter,
               SparseTensorLoadConverter, SparseExpandConverter,
               SparseCompressConverter, SparseInsertConverter,
               SparseReorderCOOConverter, SparseReMapConverter,
               SparseSliceGetterOpConverter<ToSliceOffsetOp,
                                            StorageSpecifierKind::DimOffset>,
               SparseSliceGetterOpConverter<ToSliceStrideOp,
                                            StorageSpecifierKind::DimStride>,
               SparseToPositionsConverter, SparseToCoordinatesConverter,
               SparseToCoordinatesBufferConverter, SparseToValuesConverter,
               SparseConvertConverter, SparseNewConverter,
               SparseNumberOfEntriesConverter,
               SparseHasRuntimeLibraryConverter>(typeConverter,
                                                 patterns.getContext());
  patterns.add<SparseTensorDeallocConverter>(
      typeConverter, patterns.getContext(), createSparseDeallocs);
  patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
      typeConverter, patterns.getContext(), enableBufferInitialization);
}
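A minimal sketch of how a pass might drive these rules; the type converter
name and the conversion-target setup below are assumptions for illustration,
not taken from this file:

#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
// (plus the headers of the arith/memref/scf/linalg dialects)

void runSparseTensorCodegen(mlir::ModuleOp module) {
  mlir::MLIRContext *ctx = module.getContext();
  // Assumed converter that maps sparse tensor types to their storage fields.
  mlir::SparseTensorTypeToBufferConverter typeConverter;
  mlir::RewritePatternSet patterns(ctx);
  mlir::ConversionTarget target(*ctx);
  // Assumed legality: the dialects used by the generated code remain legal.
  target.addLegalDialect<mlir::arith::ArithDialect, mlir::memref::MemRefDialect,
                         mlir::scf::SCFDialect, mlir::linalg::LinalgDialect>();
  mlir::populateSparseTensorCodegenPatterns(
      typeConverter, patterns,
      /*createSparseDeallocs=*/false, /*enableBufferInitialization=*/false);
  if (mlir::failed(
          mlir::applyPartialConversion(module, target, std::move(patterns))))
    module.emitError("sparse tensor codegen failed");
}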
Static helpers defined in this file:

static void flattenOperands(ValueRange operands, SmallVectorImpl<Value> &flattened)
    Flattens a list of operands that may contain sparse tensors.
static void createAllocFields(OpBuilder &builder, Location loc, SparseTensorType stt, bool enableInit, Value sizeHint, SmallVectorImpl<Value> &lvlSizesValues, SmallVectorImpl<Value> &fields)
    Creates allocation for each field in sparse tensor type.
static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper, MutableArrayRef<Value> fields, Value lower = Value())
    Creates a straightforward counting for-loop.
static SmallVector<ReassociationIndices> getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls)
    Creates the reassociation array.
static void genEndInsert(OpBuilder &builder, Location loc, SparseTensorDescriptor desc)
    Generates insertion finalization code.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
    Generates a store with proper index typing and proper value.
static void allocSchemeForRank(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, Level startLvl)
    Generates code that allocates a sparse storage scheme for a given rank.
static Value genCompressed(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, ValueRange lvlCoords, Value, Value parentPos, Level lvl)
    Helper method that generates the block specific to the compressed case.
static Value createAllocation(OpBuilder &builder, Location loc, MemRefType memRefType, Value sz, bool enableInit)
    Creates an allocation operation.
static void createPushback(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, SparseTensorFieldKind kind, std::optional<Level> lvl, Value value, Value repeat = Value())
    Creates a push back operation.
static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, Value sz)
    Generates a subview into the sizes.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
    Generates a load with proper index typing.
static void createDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt, ValueRange dynSizes, SmallVectorImpl<Value> &dimSizesValues)
    Creates the dim sizes array, filling in from dynamic sizes.