43static std::optional<Type> convertSparseTensorTypes(
Type type) {
45 return LLVM::LLVMPointerType::get(type.
getContext());
54 StringRef name =
"sparseLvlSize";
66 StringRef name =
"sparseDimSize";
89 if (ShapedType::isStatic(sz))
95 return genLvlSizeCall(builder, loc,
tensor, lvl);
106 if (ShapedType::isStatic(sz))
109 return genDimSizeCall(builder, loc,
tensor, dim);
118 out.reserve(dimRank);
120 out.push_back(createOrFoldDimCall(builder, loc, stt,
tensor, d));
130 fillDimSizes(builder, loc, stt,
tensor, out);
139 auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
140 return memref::AllocOp::create(rewriter, loc, memTp,
ValueRange{sz});
148 for (
const auto lt : stt.
getEncoding().getLvlTypes())
157 return memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, buf);
164 lvlBarePtrs.reserve(lvlTensors.size() + 1);
166 for (
const auto lvl : lvlTensors)
167 lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));
170 lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
171 Value idxPtr = memref::ExtractAlignedPointerAsIndexOp::create(
174 arith::IndexCastOp::create(builder, loc, builder.
getI64Type(), idxPtr);
183class NewCallParams final {
186 NewCallParams(OpBuilder &builder, Location loc)
192 NewCallParams &
genBuffers(SparseTensorType stt,
193 ArrayRef<Value> dimSizesValues,
194 Value dimSizesBuffer = Value()) {
195 assert(dimSizesValues.size() ==
static_cast<size_t>(stt.
getDimRank()));
197 params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
199 params[kParamDimSizes] = dimSizesBuffer
202 SmallVector<Value> lvlSizesValues;
204 builder, loc, stt, dimSizesValues, params[kParamDimSizes],
205 lvlSizesValues, params[kParamDim2Lvl], params[kParamLvl2Dim]);
210 params[kParamValTp] =
217 bool isInitialized()
const {
218 for (
unsigned i = 0; i < kNumStaticParams; ++i)
226 Value genNewCall(
Action action, Value ptr = Value()) {
227 assert(isInitialized() &&
"Must initialize before genNewCall");
228 StringRef name =
"newSparseTensor";
230 params[kParamPtr] = ptr ? ptr : LLVM::ZeroOp::create(builder, loc, pTp);
231 return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
236 static constexpr unsigned kNumStaticParams = 8;
237 static constexpr unsigned kNumDynamicParams = 2;
238 static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
239 static constexpr unsigned kParamDimSizes = 0;
240 static constexpr unsigned kParamLvlSizes = 1;
241 static constexpr unsigned kParamLvlTypes = 2;
242 static constexpr unsigned kParamDim2Lvl = 3;
243 static constexpr unsigned kParamLvl2Dim = 4;
244 static constexpr unsigned kParamPosTp = 5;
245 static constexpr unsigned kParamCrdTp = 6;
246 static constexpr unsigned kParamValTp = 7;
247 static constexpr unsigned kParamAction = 8;
248 static constexpr unsigned kParamPtr = 9;
253 Value params[kNumParams];
260 auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp);
270 auto resTp = MemRefType::get({ShapedType::kDynamic}, posTp);
282 auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
295 auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
309class SparseReturnConverter :
public OpConversionPattern<func::ReturnOp> {
311 using OpConversionPattern::OpConversionPattern;
313 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
314 ConversionPatternRewriter &rewriter)
const override {
315 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
321class SparseTensorLvlOpConverter :
public OpConversionPattern<LvlOp> {
323 using OpConversionPattern::OpConversionPattern;
325 matchAndRewrite(LvlOp op, OpAdaptor adaptor,
326 ConversionPatternRewriter &rewriter)
const override {
333 std::optional<int64_t> lvl = op.getConstantLvlIndex();
340 Value src = adaptor.getOperands()[0];
341 rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
347class SparseCastConverter :
public OpConversionPattern<tensor::CastOp> {
349 using OpConversionPattern::OpConversionPattern;
351 matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
352 ConversionPatternRewriter &rewriter)
const override {
356 if (!encDst || encDst != encSrc)
358 rewriter.replaceOp(op, adaptor.getOperands());
363class SparseReMapConverter :
public OpConversionPattern<ReinterpretMapOp> {
365 using OpConversionPattern::OpConversionPattern;
367 matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
368 ConversionPatternRewriter &rewriter)
const override {
370 rewriter.replaceOp(op, adaptor.getSource());
376class SparseTensorNewConverter :
public OpConversionPattern<NewOp> {
378 using OpConversionPattern::OpConversionPattern;
380 matchAndRewrite(
NewOp op, OpAdaptor adaptor,
381 ConversionPatternRewriter &rewriter)
const override {
382 Location loc = op.getLoc();
388 return rewriter.notifyMatchFailure(op,
"unsupported element type");
390 SmallVector<Value> dimSizesValues;
391 Value dimSizesBuffer;
392 Value reader =
genReader(rewriter, loc, stt, adaptor.getOperands()[0],
393 dimSizesValues, dimSizesBuffer);
395 Value tensor = NewCallParams(rewriter, loc)
396 .genBuffers(stt, dimSizesValues, dimSizesBuffer)
397 .genNewCall(Action::kFromReader, reader);
399 createFuncCall(rewriter, loc,
"delSparseTensorReader", {}, {reader},
400 EmitCInterface::Off);
401 rewriter.replaceOp(op, tensor);
408class SparseTensorAllocConverter
409 :
public OpConversionPattern<bufferization::AllocTensorOp> {
411 using OpConversionPattern::OpConversionPattern;
413 matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
414 ConversionPatternRewriter &rewriter)
const override {
419 return rewriter.notifyMatchFailure(op,
"alloc copy not implemented");
421 Location loc = op.getLoc();
423 SmallVector<Value> dimSizesValues;
424 dimSizesValues.reserve(dimRank);
425 unsigned operandCtr = 0;
426 for (
Dimension d = 0; d < dimRank; d++) {
427 dimSizesValues.push_back(
429 ? adaptor.getOperands()[operandCtr++]
434 rewriter.replaceOp(op, NewCallParams(rewriter, loc)
436 .genNewCall(Action::kEmpty));
442class SparseTensorEmptyConverter :
public OpConversionPattern<tensor::EmptyOp> {
444 using OpConversionPattern::OpConversionPattern;
446 matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
447 ConversionPatternRewriter &rewriter)
const override {
448 Location loc = op.getLoc();
454 SmallVector<Value> dimSizesValues;
455 dimSizesValues.reserve(dimRank);
456 auto shape = op.getType().getShape();
457 unsigned operandCtr = 0;
458 for (
Dimension d = 0; d < dimRank; d++) {
460 ? adaptor.getOperands()[operandCtr++]
465 rewriter.replaceOp(op, NewCallParams(rewriter, loc)
467 .genNewCall(Action::kEmpty));
473class SparseTensorReorderCOOConverter
474 :
public OpConversionPattern<ReorderCOOOp> {
476 using OpConversionPattern::OpConversionPattern;
479 matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
480 ConversionPatternRewriter &rewriter)
const override {
481 const Location loc = op->getLoc();
485 const Value src = adaptor.getInputCoo();
487 NewCallParams params(rewriter, loc);
488 SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, srcTp, src);
489 rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizesValues)
490 .genNewCall(Action::kSortCOOInPlace, src));
497class SparseTensorDeallocConverter
498 :
public OpConversionPattern<bufferization::DeallocTensorOp> {
500 using OpConversionPattern::OpConversionPattern;
502 matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
503 ConversionPatternRewriter &rewriter)
const override {
506 StringRef name =
"delSparseTensor";
507 createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
508 EmitCInterface::Off);
509 rewriter.eraseOp(op);
515class SparseTensorToPositionsConverter
516 :
public OpConversionPattern<ToPositionsOp> {
518 using OpConversionPattern::OpConversionPattern;
520 matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
521 ConversionPatternRewriter &rewriter)
const override {
523 auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
524 adaptor.getTensor(), op.getLevel());
525 rewriter.replaceOp(op, poss);
531class SparseTensorToCoordinatesConverter
532 :
public OpConversionPattern<ToCoordinatesOp> {
534 using OpConversionPattern::OpConversionPattern;
536 matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
537 ConversionPatternRewriter &rewriter)
const override {
538 const Location loc = op.getLoc();
540 auto crds = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
544 if (op.getType() != crds.getType())
545 crds = memref::CastOp::create(rewriter, loc, op.getType(), crds);
546 rewriter.replaceOp(op, crds);
552class SparseToCoordinatesBufferConverter
553 :
public OpConversionPattern<ToCoordinatesBufferOp> {
555 using OpConversionPattern::OpConversionPattern;
557 matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
558 ConversionPatternRewriter &rewriter)
const override {
559 const Location loc = op.getLoc();
561 auto crds = genCoordinatesBufferCall(
565 if (op.getType() != crds.getType())
566 crds = memref::CastOp::create(rewriter, loc, op.getType(), crds);
567 rewriter.replaceOp(op, crds);
573class SparseTensorToValuesConverter :
public OpConversionPattern<ToValuesOp> {
575 using OpConversionPattern::OpConversionPattern;
577 matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
578 ConversionPatternRewriter &rewriter)
const override {
580 auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
581 rewriter.replaceOp(op, vals);
587class SparseNumberOfEntriesConverter
588 :
public OpConversionPattern<NumberOfEntriesOp> {
590 using OpConversionPattern::OpConversionPattern;
592 matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
593 ConversionPatternRewriter &rewriter)
const override {
596 auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
598 rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero);
604class SparseTensorLoadConverter :
public OpConversionPattern<LoadOp> {
606 using OpConversionPattern::OpConversionPattern;
608 matchAndRewrite(LoadOp op, OpAdaptor adaptor,
609 ConversionPatternRewriter &rewriter)
const override {
610 if (op.getHasInserts()) {
612 StringRef name =
"endLexInsert";
613 createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
614 EmitCInterface::Off);
616 rewriter.replaceOp(op, adaptor.getOperands());
622class SparseTensorInsertConverter
623 :
public OpConversionPattern<tensor::InsertOp> {
625 using OpConversionPattern::OpConversionPattern;
627 matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
628 ConversionPatternRewriter &rewriter)
const override {
632 Location loc = op->getLoc();
639 assert(stt.
isIdentity() &&
"Run reinterpret-map before conversion.");
642 Value lvlCoords, vref;
644 OpBuilder::InsertionGuard guard(rewriter);
645 Operation *loop = op;
650 if (llvm::isa<LoopLikeOpInterface>(loop)) {
652 rewriter.setInsertionPoint(loop);
654 lvlCoords =
genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
657 storeAll(rewriter, loc, lvlCoords, adaptor.getIndices());
658 memref::StoreOp::create(rewriter, loc, adaptor.getScalar(), vref);
661 {adaptor.getDest(), lvlCoords, vref}, EmitCInterface::On);
662 rewriter.replaceOp(op, adaptor.getDest());
668class SparseTensorExpandConverter :
public OpConversionPattern<ExpandOp> {
670 using OpConversionPattern::OpConversionPattern;
672 matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
673 ConversionPatternRewriter &rewriter)
const override {
674 Location loc = op->getLoc();
676 Type eltType = srcTp.getElementType();
677 Type boolType = rewriter.getIntegerType(1);
678 Type idxType = rewriter.getIndexType();
680 rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
682 Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
683 srcTp.getLvlRank() - 1);
687 Value values = genAlloc(rewriter, loc, sz, eltType);
688 Value filled = genAlloc(rewriter, loc, sz, boolType);
689 Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
696 linalg::FillOp::create(rewriter, loc,
699 linalg::FillOp::create(rewriter, loc,
703 assert(op.getNumResults() == 4);
704 rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
710class SparseTensorCompressConverter :
public OpConversionPattern<CompressOp> {
712 using OpConversionPattern::OpConversionPattern;
714 matchAndRewrite(CompressOp op, OpAdaptor adaptor,
715 ConversionPatternRewriter &rewriter)
const override {
716 Location loc = op->getLoc();
721 Value values = adaptor.getValues();
722 Value filled = adaptor.getFilled();
723 Value added = adaptor.getAdded();
724 Value count = adaptor.getCount();
725 Value tensor = adaptor.getTensor();
729 auto lvlCoords =
genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
730 storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
733 {tensor, lvlCoords, values, filled, added, count},
735 Operation *parent =
getTop(op);
736 rewriter.setInsertionPointAfter(parent);
737 rewriter.replaceOp(op, adaptor.getTensor());
739 memref::DeallocOp::create(rewriter, loc, values);
740 memref::DeallocOp::create(rewriter, loc, filled);
741 memref::DeallocOp::create(rewriter, loc, added);
747class SparseTensorAssembleConverter :
public OpConversionPattern<AssembleOp> {
749 using OpConversionPattern::OpConversionPattern;
751 matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
752 ConversionPatternRewriter &rewriter)
const override {
753 const Location loc = op->getLoc();
755 assert(dstTp.hasStaticDimShape());
756 SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, dstTp);
762 NewCallParams(rewriter, loc)
763 .genBuffers(dstTp.withoutDimToLvl(), dimSizesValues)
764 .genNewCall(Action::kPack,
765 genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
766 adaptor.getValues()));
767 rewriter.replaceOp(op, dst);
779class SparseTensorDisassembleConverter
780 :
public OpConversionPattern<DisassembleOp> {
782 using OpConversionPattern::OpConversionPattern;
784 matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
785 ConversionPatternRewriter &rewriter)
const override {
786 Location loc = op->getLoc();
788 SmallVector<Value> retVal;
789 SmallVector<Value> retLen;
792 Level trailCOOLen = 0;
793 for (
Level l = 0; l < lvlRank; l++) {
800 trailCOOLen = lvlRank - l;
805 genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
807 auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
808 retVal.push_back(poss);
813 genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
815 auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
816 retVal.push_back(crds);
821 if (trailCOOLen != 0) {
822 uint64_t cooStartLvl = lvlRank - trailCOOLen;
827 auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
830 auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
831 retVal.push_back(poss);
836 auto buf =
genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
837 auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
839 auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
843 auto bufLen = arith::MulIOp::create(rewriter, loc, crdLen, two);
844 Type indexType = rewriter.getIndexType();
847 scf::ForOp forOp = scf::ForOp::create(rewriter, loc, zero, crdLen, one);
848 auto idx = forOp.getInductionVar();
849 rewriter.setInsertionPointToStart(forOp.getBody());
850 auto c0 = memref::LoadOp::create(rewriter, loc, crds0, idx);
851 auto c1 = memref::LoadOp::create(rewriter, loc, crds1, idx);
852 SmallVector<Value> args;
854 args.push_back(zero);
855 memref::StoreOp::create(rewriter, loc, c0, buf, args);
857 memref::StoreOp::create(rewriter, loc, c1, buf, args);
858 rewriter.setInsertionPointAfter(forOp);
859 auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
860 retVal.push_back(buf);
864 auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
865 auto valLenTp = op.getValLen().getType();
867 retVal.push_back(vals);
871 assert(retVal.size() + retLen.size() == op.getNumResults());
872 for (
unsigned i = 0, sz = retVal.size(); i < sz; i++) {
873 auto tensor = bufferization::ToTensorOp::create(
877 tensor::CastOp::create(rewriter, loc, op.getResultTypes()[i], tensor);
881 retVal.append(retLen.begin(), retLen.end());
882 rewriter.replaceOp(op, retVal);
887struct SparseHasRuntimeLibraryConverter
888 :
public OpConversionPattern<HasRuntimeLibraryOp> {
889 using OpConversionPattern::OpConversionPattern;
891 matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
892 ConversionPatternRewriter &rewriter)
const override {
893 auto i1Type = rewriter.getI1Type();
894 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
895 op, i1Type, rewriter.getIntegerAttr(i1Type, 1));
907 addConversion([](
Type type) {
return type; });
908 addConversion(convertSparseTensorTypes);
920 .
add<SparseReturnConverter, SparseTensorLvlOpConverter,
921 SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
922 SparseTensorAllocConverter, SparseTensorEmptyConverter,
923 SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
924 SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
925 SparseToCoordinatesBufferConverter, SparseTensorToValuesConverter,
926 SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
927 SparseTensorInsertConverter, SparseTensorExpandConverter,
928 SparseTensorCompressConverter, SparseTensorAssembleConverter,
929 SparseTensorDisassembleConverter, SparseHasRuntimeLibraryConverter>(
static void genBuffers(CodegenEnv &env, OpBuilder &builder)
Local bufferization of all dense and sparse data structures.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
SparseTensorTypeToPtrConverter()
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A wrapper around RankedTensorType, which has three goals:
Size getDynamicDimSize(Dimension d) const
Safely looks up the requested dimension-DynSize.
bool isWithPos(Level l) const
Type getElementType() const
bool isLooseCompressedLvl(Level l) const
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
Dimension getDimRank() const
Returns the dimension-rank.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
bool isCompressedLvl(Level l) const
bool isWithCrd(Level l) const
Level getLvlRank() const
Returns the level-rank.
SparseTensorEncodingAttr getEncoding() const
bool isDynamicDim(Dimension d) const
Returns true if the given dimension has dynamic size.
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
AffineMap getDimToLvl() const
Returns the dimToLvl mapping (or the null-map for the identity).
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
bool isUniqueLvl(Level l) const
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp)
Generates an uninitialized temporary buffer with room for one value of the given type,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
Value constantLevelTypeEncoding(OpBuilder &builder, Location loc, LevelType lt)
Generates a constant of the internal dimension level type encoding.
Value allocaBuffer(OpBuilder &builder, Location loc, ValueRange values)
Generates a temporary buffer, initializes it with the given contents, and returns it as type memref<?
Value constantPosTypeEncoding(OpBuilder &builder, Location loc, SparseTensorEncodingAttr enc)
Generates a constant of the internal type-encoding for position overhead storage.
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
Action
The actions performed by @newSparseTensor.
TypedValue< BaseMemRefType > genToMemref(OpBuilder &builder, Location loc, Value tensor)
Value constantAction(OpBuilder &builder, Location loc, Action action)
Generates a constant of the given Action.
Value constantCrdTypeEncoding(OpBuilder &builder, Location loc, SparseTensorEncodingAttr enc)
Generates a constant of the internal type-encoding for coordinate overhead storage.
StringRef overheadTypeFunctionSuffix(OverheadType ot)
Convert OverheadType to its function-name suffix.
bool isValidPrimaryType(Type elemTp)
Returns true if the given type is a valid sparse tensor element type supported by the runtime library...
uint64_t Level
The type of level identifiers and level-ranks.
Operation * getTop(Operation *op)
Scans to top of generated loop.
Type getOpaquePointerType(MLIRContext *ctx)
Returns the equivalent of void* for opaque arguments to the execution engine.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Value genMapBuffers(OpBuilder &builder, Location loc, SparseTensorType stt, ArrayRef< Value > dimSizesValues, Value dimSizesBuffer, SmallVectorImpl< Value > &lvlSizesValues, Value &dim2lvlBuffer, Value &lvl2dimBuffer)
Generates code to set up the buffer parameters for a map.
Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt, Value tensor, SmallVectorImpl< Value > &dimSizesValues, Value &dimSizesBuffer)
Generates code that opens a reader and sets the dimension sizes.
Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem, Type dstTp)
Add conversion from scalar to given type (possibly a 0-rank tensor).
Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp)
Generates an uninitialized temporary buffer of the given size and type, but returns it as type memref...
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Creates a CallOp to the function reference returned by getFunc() in the builder's module.
StringRef primaryTypeFunctionSuffix(PrimaryType pt)
Convert PrimaryType to its function-name suffix.
Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc, Type elemTp)
Generates a constant of the internal type-encoding for primary storage.
void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs, size_t offsetIdx=0, Value offsetVal=Value())
Stores all the values of vs into the memref mem, which must have rank-1 and size greater-or-equal to ...
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void populateSparseTensorConversionPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Sets up sparse tensor conversion rules.