49 for (
auto operand : operands) {
56 flattened.append(tuple.getOperands().begin(), tuple.getOperands().end());
58 flattened.push_back(operand);
66 return builder.
create<memref::LoadOp>(loc, mem, idx);
73 val =
genCast(builder, loc, val,
74 cast<ShapedType>(mem.
getType()).getElementType());
75 builder.
create<memref::StoreOp>(loc, val, mem, idx);
86 scf::ForOp forOp = builder.
create<scf::ForOp>(loc, lower, upper, one, fields);
87 for (
unsigned i = 0, e = fields.size(); i < e; i++)
88 fields[i] = forOp.getRegionIterArg(i);
102 auto pushBackOp = builder.
create<PushBackOp>(
104 genCast(builder, loc, value, etp), repeat);
108 pushBackOp.getNewSize());
117 for (
Level lvl = startLvl; lvl < lvlRank; lvl++) {
128 linear = builder.
create<arith::MulIOp>(loc, linear, two);
141 linear = builder.
create<arith::MulIOp>(loc, linear, size);
146 std::nullopt, valZero, linear);
151 MemRefType memRefType,
Value sz,
153 Value buffer = builder.
create<memref::AllocOp>(loc, memRefType, sz);
154 Type elemType = memRefType.getElementType();
157 builder.
create<linalg::FillOp>(loc, fillValue, buffer);
167 dimSizesValues.clear();
168 dimSizesValues.reserve(dimRank);
171 dimSizesValues.push_back(ShapedType::isDynamic(sz)
190 Value posHeuristic, crdHeuristic, valHeuristic;
192 valHeuristic = lvlSizesValues[0];
193 for (
Level lvl = 1; lvl < lvlRank; lvl++)
195 builder.
create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
196 }
else if (sizeHint) {
199 crdHeuristic = builder.
create<arith::MulIOp>(
202 posHeuristic = builder.
create<arith::AddIOp>(
204 crdHeuristic = sizeHint;
206 posHeuristic = crdHeuristic =
constantIndex(builder, loc, 16);
208 valHeuristic = sizeHint;
210 posHeuristic = crdHeuristic = valHeuristic =
217 [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
220 assert(fields.size() == fIdx);
228 posHeuristic, enableInit);
232 crdHeuristic, enableInit);
236 valHeuristic, enableInit);
240 fields.push_back(field);
249 for (
Level lvl = 0, lvlRank = stt.
getLvlRank(); lvl < lvlRank; lvl++) {
250 desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
286 assert(lvl < lvlRank &&
"Level is out of bounds");
287 assert(lvlCoords.size() ==
static_cast<size_t>(lvlRank) &&
288 "Level-rank mismatch");
296 const Value pp1 = builder.
create<arith::AddIOp>(loc, parentPos, one);
298 const Value pstart =
genLoad(builder, loc, positionsAtLvl, parentPos);
299 const Value pstop =
genLoad(builder, loc, positionsAtLvl, pp1);
301 const Value crdStrideC =
304 crdStrideC ? builder.
create<arith::DivUIOp>(loc, crdMsz, crdStrideC)
307 loc,
genCast(builder, loc, pstop, indexType), one);
309 Value lt = builder.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
311 types.push_back(boolType);
312 scf::IfOp ifOp1 = builder.
create<scf::IfOp>(loc, types, lt,
true);
317 crdStrideC ? builder.
create<arith::MulIOp>(loc, plast, crdStrideC)
320 loc, arith::CmpIPredicate::eq,
genCast(builder, loc, crd, indexType),
322 builder.
create<scf::YieldOp>(loc, eq);
325 genStore(builder, loc, msz, positionsAtLvl, parentPos);
333 for (
unsigned i = 0, e = desc.
getNumFields(); i < e; i++)
335 types.push_back(indexType);
338 scf::IfOp ifOp2 = builder.
create<scf::IfOp>(loc, types, p,
true);
350 Value mszp1 = builder.
create<arith::AddIOp>(loc, msz, one);
351 genStore(builder, loc, mszp1, positionsAtLvl, pp1);
355 if ((lvl + 1) < lvlRank)
365 for (
unsigned i = 0, e = desc.
getNumFields(); i < e; i++)
366 desc.
setField(i, ifOp2.getResult(o++));
367 return ifOp2.getResult(o);
375 for (
Level lvl = 0; lvl < lvlRank; lvl++) {
392 scf::ForOp loop =
createFor(builder, loc, hi, inits, one);
393 Value i = loop.getInductionVar();
394 Value oldv = loop.getRegionIterArg(0);
398 loc, arith::CmpIPredicate::eq, newv, posZero);
402 genStore(builder, loc, oldv, posMemRef, i);
403 builder.
create<scf::YieldOp>(loc, oldv);
405 builder.
create<scf::YieldOp>(loc, newv);
407 builder.
create<scf::YieldOp>(loc, ifOp.getResult(0));
420 auto elemTp = llvm::cast<MemRefType>(mem.
getType()).getElementType();
422 .
create<memref::SubViewOp>(
434 for (
int i = 0, e = srcTp.getRank(); i < e; i++)
435 reassociation.push_back(i);
436 return reassociation;
446 class SparseInsertGenerator
466 const Level lvlRank = stt.getLvlRank();
471 llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
472 Value value = args.back();
475 for (
Level lvl = 0; lvl < lvlRank; lvl++) {
476 const auto lt = stt.getLvlType(lvl);
487 parentPos = builder.
create<arith::MulIOp>(loc, parentPos, two);
490 genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
504 Value mult = builder.
create<arith::MulIOp>(loc, size, parentPos);
505 parentPos = builder.
create<arith::AddIOp>(loc, mult, coords[lvl]);
509 if (!stt.isDenseLvl(lvlRank - 1))
511 std::nullopt, value);
517 std::string getMangledFuncName() {
520 constexpr
const char kInsertFuncNamePrefix[] =
"_insert_";
523 llvm::raw_svector_ostream nameOstream(nameBuffer);
524 nameOstream << kInsertFuncNamePrefix;
525 const Level lvlRank = stt.getLvlRank();
526 for (
Level l = 0; l < lvlRank; l++) {
530 lvlType.begin(), lvlType.end(),
531 [](
char c) { return c ==
'(' || c ==
','; },
'_');
532 llvm::erase_if(lvlType, [](
char c) {
return c ==
')' || c ==
' '; });
533 nameOstream << lvlType <<
"_";
538 for (
const auto sz : stt.getDimShape())
539 nameOstream << sz <<
"_";
541 if (!stt.isIdentity())
542 nameOstream << stt.getDimToLvl() <<
"_";
543 nameOstream << stt.getElementType() <<
"_";
544 nameOstream << stt.getCrdWidth() <<
"_" << stt.getPosWidth();
545 return nameOstream.str().str();
557 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
573 matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
588 auto newCall = rewriter.
create<func::CallOp>(loc, op.getCallee(),
589 finalRetTy, flattened);
594 unsigned retOffset = 0;
599 assert(retOffset < newCall.getNumResults());
600 auto retType = ret.getType();
601 if (
failed(typeConverter->convertType(retType, sparseFlat)))
602 llvm_unreachable(
"Failed to convert type in sparse tensor codegen");
605 assert(!sparseFlat.empty());
606 if (sparseFlat.size() > 1) {
607 auto flatSize = sparseFlat.size();
609 newCall.result_begin() + retOffset,
610 newCall.result_begin() + retOffset + flatSize));
611 castedRet.push_back(
genTuple(rewriter, loc, retType, fields));
612 retOffset += flatSize;
615 castedRet.push_back(newCall.getResult(retOffset));
632 matchAndRewrite(LvlOp op, OpAdaptor adaptor,
634 std::optional<int64_t> lvl = op.getConstantLvlIndex();
650 matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor,
681 rewriter.
replaceOp(op, adaptor.getInputCoo());
686 template <
typename Op, StorageSpecifierKind kind>
691 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
696 op.getDim().getZExtValue());
708 matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
713 if (!encDst || encDst != encSrc)
715 rewriter.
replaceOp(op, adaptor.getOperands());
724 matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
727 rewriter.
replaceOp(op, adaptor.getSource());
733 class SparseTensorAllocConverter
740 enableBufferInitialization(enableInit) {}
743 matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
746 if (!resType.hasEncoding())
757 auto memrefTp = cast<MemRefType>(field.getType());
758 auto size = rewriter.
create<memref::DimOp>(loc, field, 0);
761 rewriter.
create<memref::CopyOp>(loc, field, copied);
762 fields.push_back(copied);
771 if (!resType.isIdentity()) {
773 op,
"try run --sparse-reinterpret-map before codegen");
781 Value sizeHint = op.getSizeHint();
784 sizeHint, lvlSizesValues, fields);
792 bool enableBufferInitialization;
802 enableBufferInitialization(enableInit) {}
805 matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
808 if (!resType.hasEncoding())
811 if (!resType.isIdentity()) {
813 op,
"try run --sparse-reinterpret-map before codegen");
825 sizeHint, lvlSizesValues, fields);
833 bool enableBufferInitialization;
837 class SparseTensorDeallocConverter
844 createDeallocs(createDeallocs) {}
847 matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
855 if (createDeallocs) {
861 rewriter.
create<memref::DeallocOp>(loc, input);
868 const bool createDeallocs;
876 matchAndRewrite(LoadOp op, OpAdaptor adaptor,
881 if (op.getHasInserts())
894 matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
901 Type eltType = srcType.getElementType();
909 const auto sz = desc.
getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
911 const auto genAlloc = [&](
Type t) {
918 Value values = genAlloc(eltType);
919 Value filled = genAlloc(boolType);
920 Value added = genAlloc(idxType);
927 rewriter.
create<linalg::FillOp>(
930 rewriter.
create<linalg::FillOp>(
935 rewriter.
replaceOp(op, {values, filled, added, zero});
945 matchAndRewrite(CompressOp op, OpAdaptor adaptor,
950 Value values = adaptor.getValues();
951 Value filled = adaptor.getFilled();
952 Value added = adaptor.getAdded();
953 Value count = adaptor.getCount();
955 Type eltType = dstType.getElementType();
959 if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
962 rewriter.
getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
979 Value i = loop.getInductionVar();
985 llvm::map_range(desc.
getFields(), [](
Value v) { return v.getType(); }));
986 params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
987 params.push_back(crd);
988 params.push_back(value);
989 SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
994 rewriter.
create<scf::YieldOp>(loc, insertRet);
997 Value result =
genTuple(rewriter, loc, dstType, loop->getResults());
1001 rewriter.
create<memref::DeallocOp>(loc, values);
1002 rewriter.
create<memref::DeallocOp>(loc, filled);
1003 rewriter.
create<memref::DeallocOp>(loc, added);
1015 matchAndRewrite(InsertOp op, OpAdaptor adaptor,
1021 params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
1022 params.push_back(adaptor.getValue());
1023 SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
1028 genTuple(rewriter, loc, op.getTensor().getType(), ret));
1036 using OpAdaptor =
typename ToPositionsOp::Adaptor;
1039 matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
1051 class SparseToCoordinatesConverter
1054 using OpAdaptor =
typename ToCoordinatesOp::Adaptor;
1057 matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
1064 Value field = desc.getCrdMemRefOrView(rewriter, loc, op.getLevel());
1070 if (resType != field.
getType())
1071 field = rewriter.
create<memref::CastOp>(loc, resType, field);
1079 class SparseToCoordinatesBufferConverter
1082 using OpAdaptor =
typename ToCoordinatesBufferOp::Adaptor;
1085 matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
1100 using OpAdaptor =
typename ToValuesOp::Adaptor;
1103 matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
1119 matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
1122 SparseTensorEncodingAttr encSrc =
1126 assert(!encDst.isSlice() &&
"Cannot convert to a sparse tensor slices.");
1130 if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
1136 Type srcElemTp = op.getSource().getType().getElementType();
1138 if (retElemTp == srcElemTp && encDst == encSrc) {
1139 rewriter.
replaceOp(op, adaptor.getSource());
1156 [&rewriter, &fields, srcDesc,
1160 if (fKind == SparseTensorFieldKind::StorageSpec) {
1161 fields.push_back(srcDesc.getSpecifier());
1164 Value srcMem = srcDesc.getMemRefField(fIdx);
1168 Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
1169 auto dstMem = rewriter.create<memref::AllocOp>(
1170 loc, cast<MemRefType>(fTp), sz);
1171 if (fTp != srcMem.getType()) {
1174 rewriter, loc, constantIndex(rewriter, loc, 0), sz,
1175 constantIndex(rewriter, loc, 1),
1176 [srcMem, &dstMem](OpBuilder &builder, Location loc,
1178 Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
1179 Value casted = genCast(builder, loc, v,
1180 dstMem.getType().getElementType());
1181 builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
1187 rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
1189 fields.push_back(dstMem);
1200 class SparseExtractSliceConverter
1205 matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
1212 if (!srcEnc || !dstEnc || !dstEnc.isSlice())
1214 assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
1219 auto newSpec = rewriter.
create<StorageSpecifierInitOp>(
1225 op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
1240 assert(srcEnc.isIdentity());
1257 class SparseNumberOfEntriesConverter
1262 matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
1276 matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
1285 [&rewriter, &fields, &op, &stt,
1288 assert(fields.size() == fIdx);
1289 if (fKind == SparseTensorFieldKind::StorageSpec) {
1291 SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
1294 Value tensor = fKind == SparseTensorFieldKind::ValMemRef
1296 : op.getLevels()[fIdx];
1299 if (mem.getType().getRank() > 1) {
1302 mem = rewriter.
create<memref::CastOp>(
1304 rewriter.
create<memref::CollapseShapeOp>(loc, mem, reassoc));
1306 mem = rewriter.
create<memref::CastOp>(loc, fType, mem);
1308 fields.push_back(mem);
1321 Level trailCOORank = stt.getLvlRank() - trailCOOStart;
1323 for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
1324 assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
1328 auto lvlSize =
constantIndex(rewriter, loc, stt.getDimShape()[lvl]);
1329 desc.
setLvlSize(rewriter, loc, lvl, lvlSize);
1332 if (lvl > trailCOOStart)
1339 memSize = rewriter.
create<arith::MulIOp>(loc, lvlSize, memSize);
1340 posBack = rewriter.
create<arith::SubIOp>(loc, memSize, c1);
1347 memSize = rewriter.
create<arith::MulIOp>(loc, memSize, c2);
1348 posBack = rewriter.
create<arith::SubIOp>(loc, memSize, c1);
1352 memSize = rewriter.
create<arith::AddIOp>(loc, memSize, c1);
1357 posBack = rewriter.
create<arith::SubIOp>(loc, posBack, c1);
1361 if (lvl == trailCOOStart) {
1376 struct SparseDisassembleOpConverter
1384 matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
1394 if (fKind == SparseTensorFieldKind::StorageSpec)
1399 if (fKind == SparseTensorFieldKind::ValMemRef) {
1402 dst =
genToMemref(rewriter, loc, op.getOutValues());
1407 retMem.insert(retMem.begin(), dst);
1408 Type valLenTp = op.getValLen().getType();
1409 retLen.insert(retLen.begin(),
1412 assert(fKind == SparseTensorFieldKind::PosMemRef ||
1413 fKind == SparseTensorFieldKind::CrdMemRef);
1415 sz = fKind == SparseTensorFieldKind::PosMemRef
1419 dst =
genToMemref(rewriter, loc, op.getOutLevels()[fid]);
1420 retMem.push_back(dst);
1422 Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
1425 Value flatOut = dst;
1426 if (dst.getType().getRank() != 1) {
1428 flatOut = rewriter.
create<memref::CollapseShapeOp>(loc, dst, reassoc);
1432 rewriter.
create<memref::CopyOp>(loc, srcMem, dstMem);
1438 llvm::map_range(retMem, [&rewriter, loc](
Value v) ->
Value {
1439 return rewriter.
create<bufferization::ToTensorOp>(loc, v);
1442 retValues.append(retLen.begin(), retLen.end());
1451 matchAndRewrite(
NewOp op, OpAdaptor adaptor,
1457 if (!dstTp.hasEncoding() ||
getCOOStart(dstTp.getEncoding()) != 0)
1472 Value dimSizesBuffer;
1473 Value reader =
genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
1474 dimSizesValues, dimSizesBuffer);
1479 {indexTp}, {reader}, EmitCInterface::Off)
1484 Value dim2lvlBuffer;
1485 Value lvl2dimBuffer;
1486 genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer,
1487 lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
1490 Value sizeHint = nse;
1493 lvlSizesValues, fields);
1500 const Type elemTp = dstTp.getElementType();
1501 const Type crdTp = dstTp.getCrdType();
1502 SmallString<32> readToBuffersFuncName{
"getSparseTensorReaderReadToBuffers",
1507 {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
1513 const Level lvlRank = dstTp.getLvlRank();
1514 if (dstTp.isOrderedLvl(lvlRank - 1)) {
1517 loc, arith::CmpIPredicate::eq, isSorted, kFalse);
1519 rewriter.
create<scf::IfOp>(loc, notSorted,
false);
1524 SparseTensorSortKind::HybridQuickSort);
1531 const Type posTp = dstTp.getPosType();
1532 const Value posNse =
genCast(rewriter, loc, nse, posTp);
1533 rewriter.
create<memref::StoreOp>(loc, posNse, posMemref0, c1);
1536 Value coordinatesSize = rewriter.
create<arith::MulIOp>(
1544 createFuncCall(rewriter, loc,
"delSparseTensorReader", {}, {reader},
1545 EmitCInterface::Off);
1563 bool createSparseDeallocs,
bool enableBufferInitialization) {
1564 patterns.
add<SparseAssembleOpConverter, SparseDisassembleOpConverter,
1565 SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
1566 SparseCastConverter, SparseExtractSliceConverter,
1567 SparseTensorLoadConverter, SparseExpandConverter,
1568 SparseCompressConverter, SparseInsertConverter,
1569 SparseReorderCOOConverter, SparseReMapConverter,
1570 SparseSliceGetterOpConverter<ToSliceOffsetOp,
1571 StorageSpecifierKind::DimOffset>,
1572 SparseSliceGetterOpConverter<ToSliceStrideOp,
1573 StorageSpecifierKind::DimStride>,
1574 SparseToPositionsConverter, SparseToCoordinatesConverter,
1575 SparseToCoordinatesBufferConverter, SparseToValuesConverter,
1576 SparseConvertConverter, SparseNewConverter,
1577 SparseNumberOfEntriesConverter>(typeConverter,
1579 patterns.
add<SparseTensorDeallocConverter>(
1580 typeConverter, patterns.
getContext(), createSparseDeallocs);
1581 patterns.
add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
1582 typeConverter, patterns.
getContext(), enableBufferInitialization);
static void flattenOperands(ValueRange operands, SmallVectorImpl< Value > &flattened)
Flattens a list of operands that may contain sparse tensors.
static void createAllocFields(OpBuilder &builder, Location loc, SparseTensorType stt, bool enableInit, Value sizeHint, SmallVectorImpl< Value > &lvlSizesValues, SmallVectorImpl< Value > &fields)
Creates allocation for each field in sparse tensor type.
static ReassociationIndices getReassociationForFlattening(ShapedType srcTp)
Creates the reassociation array.
static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper, MutableArrayRef< Value > fields, Value lower=Value())
Creates a straightforward counting for-loop.
static void genEndInsert(OpBuilder &builder, Location loc, SparseTensorDescriptor desc)
Generates insertion finalization code.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static void allocSchemeForRank(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, Level startLvl)
Generates code that allocates a sparse storage scheme for given rank.
static Value genCompressed(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, ValueRange lvlCoords, Value, Value parentPos, Level lvl)
Helper method that generates the block specific to the compressed case.
static Value createAllocation(OpBuilder &builder, Location loc, MemRefType memRefType, Value sz, bool enableInit)
Creates allocation operation.
static void createPushback(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, SparseTensorFieldKind kind, std::optional< Level > lvl, Value value, Value repeat=Value())
Creates a push back operation.
static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, Value sz)
Generates a subview into the sizes.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
static void createDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt, ValueRange dynSizes, SmallVectorImpl< Value > &dimSizesValues)
Creates the dim sizes array, filling in from dynamic sizes.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
IntegerAttr getIndexAttr(int64_t value)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerType getIntegerType(unsigned width)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A helper class to simplify lowering operations with/without function calls.
Using SmallVector for the mutable descriptor allows users to reuse it as a tmp buffer to append value fo...
void setMemRefField(SparseTensorFieldKind kind, std::optional< Level > lvl, Value v)
Adds additional setters for the mutable descriptor, which update the value of the required field.
void setSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl, Value v)
void setSpecifier(Value newSpec)
void setPosMemSize(OpBuilder &builder, Location loc, Level lvl, Value v)
void setValMemSize(OpBuilder &builder, Location loc, Value v)
void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value v)
void setField(FieldIndex fidx, Value v)
void setCrdMemSize(OpBuilder &builder, Location loc, Level lvl, Value v)
Value getSpecifier() const
Getters: get the value for required field.
RankedTensorType getRankedTensorType() const
ValueArrayRef getFields() const
std::pair< FieldIndex, unsigned > getCrdMemRefIndexAndStride(Level lvl) const
Value getValMemSize(OpBuilder &builder, Location loc) const
Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional< Level > lvl) const
StorageLayout getLayout() const
ValueRange getMemRefFields() const
Value getAOSMemRef() const
Type getMemRefElementType(SparseTensorFieldKind kind, std::optional< Level > lvl) const
unsigned getNumFields() const
Value getCrdMemSize(OpBuilder &builder, Location loc, Level lvl) const
Value getMemRefField(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Value getPosMemSize(OpBuilder &builder, Location loc, Level lvl) const
Value getValMemRef() const
Value getField(FieldIndex fidx) const
Value getLvlSize(OpBuilder &builder, Location loc, Level lvl) const
Value getPosMemRef(Level lvl) const
Uses ValueRange for immutable descriptors.
static Value getInitValue(OpBuilder &builder, Location loc, SparseTensorType stt)
A wrapper around RankedTensorType, which has three goals:
Type getElementType() const
ArrayRef< Size > getDimShape() const
Returns the dimension-shape.
bool isAllOrdered() const
Returns true for tensors where every level is ordered.
Dimension getDimRank() const
Returns the dimension-rank.
bool isAllDense() const
Returns true for tensors where every level is dense.
bool hasSameDimToLvl(const SparseTensorType &other) const
Returns true iff the two types have the same mapping.
RankedTensorType getRankedTensorType() const
Explicitly convert to RankedTensorType.
bool isCompressedLvl(Level l) const
Level getLvlRank() const
Returns the level-rank.
SparseTensorEncodingAttr getEncoding() const
Returns the encoding (or the null-attribute for dense-tensors).
bool isDenseLvl(Level l) const
LevelType getLvlType(Level l) const
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
bool isUniqueLvl(Level l) const
void foreachField(llvm::function_ref< bool(FieldIndex, SparseTensorFieldKind, Level, LevelType)>) const
For each field that will be allocated for the given sparse tensor encoding, calls the callback with t...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
constexpr const char * toMLIRString(LevelType lt)
Returns string representation of the given dimension level type.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Level getCOOStart(SparseTensorEncodingAttr enc)
Returns the starting level for a trailing COO region that spans at least two levels.
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
constexpr bool isWithPosLT(LevelType lt)
Check if the LevelType needs positions array.
constexpr bool isLooseCompressedLT(LevelType lt)
Check if the LevelType is loose compressed (regardless of properties).
void foreachFieldAndTypeInSparseTensor(SparseTensorType, llvm::function_ref< bool(Type, FieldIndex, SparseTensorFieldKind, Level, LevelType)>)
unsigned FieldIndex
The type of field indices.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
uint64_t Level
The type of level identifiers and level-ranks.
TypedValue< BaseMemRefType > genToMemref(OpBuilder &builder, Location loc, Value tensor)
Value genTuple(OpBuilder &builder, Location loc, Type tp, ValueRange values)
Packs the given values as a "tuple" value.
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
StringRef overheadTypeFunctionSuffix(OverheadType ot)
Convert OverheadType to its function-name suffix.
constexpr bool isWithCrdLT(LevelType lt)
Check if the LevelType needs coordinates array.
constexpr bool is2OutOf4LT(LevelType lt)
Check if the LevelType is 2OutOf4 (regardless of properties).
constexpr bool isDenseLT(LevelType lt)
Check if the LevelType is dense (regardless of properties).
bool isUniqueCOOType(Type tp)
Returns true iff the given type is a COO type where the last level is unique.
Operation * getTop(Operation *op)
Scans to top of generated loop.
constexpr bool isSingletonLT(LevelType lt)
Check if the LevelType is singleton (regardless of properties).
LevelType
This enum defines all the sparse representations supportable by the SparseTensor dialect.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Value genMapBuffers(OpBuilder &builder, Location loc, SparseTensorType stt, ArrayRef< Value > dimSizesValues, Value dimSizesBuffer, SmallVectorImpl< Value > &lvlSizesValues, Value &dim2lvlBuffer, Value &lvl2dimBuffer)
Generates code to set up the buffer parameters for a map.
Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, Value s)
Generates a pointer/index load from the sparse storage scheme.
Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt, Value tensor, SmallVectorImpl< Value > &dimSizesValues, Value &dimSizesBuffer)
Generates code that opens a reader and sets the dimension sizes.
Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem, Type dstTp)
Add conversion from scalar to given type (possibly a 0-rank tensor).
SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor)
constexpr bool isCompressedLT(LevelType lt)
Check if the LevelType is compressed (regardless of properties).
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SparseTensorFieldKind
The sparse tensor storag...
func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Creates a CallOp to the function reference returned by getFunc() in the builder's module.
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy)
Add type casting between arith and index types when needed.
StringRef primaryTypeFunctionSuffix(PrimaryType pt)
Convert PrimaryType to its function-name suffix.
MutSparseTensorDescriptor getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl< Value > &fields)
Value genValMemSize(OpBuilder &builder, Location loc, Value tensor)
Generates code to retrieve the values size for the sparse tensor.
StorageSpecifierKind toSpecifierKind(SparseTensorFieldKind kind)
UnrealizedConversionCastOp getTuple(Value tensor)
Returns the "tuple" value of the adapted tensor.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateSparseTensorCodegenPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns, bool createSparseDeallocs, bool enableBufferInitialization)
Sets up sparse tensor codegen rules.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.