43static std::optional<Type> convertSparseTensorTypes(
Type type) {
45 return LLVM::LLVMPointerType::get(type.
getContext());
54 StringRef name =
"sparseLvlSize";
66 StringRef name =
"sparseDimSize";
89 if (ShapedType::isStatic(sz))
95 return genLvlSizeCall(builder, loc,
tensor, lvl);
106 if (ShapedType::isStatic(sz))
109 return genDimSizeCall(builder, loc,
tensor, dim);
118 out.reserve(dimRank);
120 out.push_back(createOrFoldDimCall(builder, loc, stt,
tensor, d));
130 fillDimSizes(builder, loc, stt,
tensor, out);
139 auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
140 return memref::AllocOp::create(rewriter, loc, memTp,
ValueRange{sz});
148 for (
const auto lt : stt.
getEncoding().getLvlTypes())
157 return memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, buf);
164 lvlBarePtrs.reserve(lvlTensors.size() + 1);
166 for (
const auto lvl : lvlTensors)
167 lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));
170 lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
171 Value idxPtr = memref::ExtractAlignedPointerAsIndexOp::create(
174 arith::IndexCastOp::create(builder, loc, builder.
getI64Type(), idxPtr);
183class NewCallParams final {
186 NewCallParams(OpBuilder &builder, Location loc)
192 NewCallParams &
genBuffers(SparseTensorType stt,
193 ArrayRef<Value> dimSizesValues,
194 Value dimSizesBuffer = Value()) {
195 assert(dimSizesValues.size() ==
static_cast<size_t>(stt.
getDimRank()));
197 params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
199 params[kParamDimSizes] = dimSizesBuffer
202 SmallVector<Value> lvlSizesValues;
204 builder, loc, stt, dimSizesValues, params[kParamDimSizes],
205 lvlSizesValues, params[kParamDim2Lvl], params[kParamLvl2Dim]);
210 params[kParamValTp] =
217 bool isInitialized()
const {
218 for (
unsigned i = 0; i < kNumStaticParams; ++i)
226 Value genNewCall(
Action action, Value ptr = Value()) {
227 assert(isInitialized() &&
"Must initialize before genNewCall");
228 StringRef name =
"newSparseTensor";
230 params[kParamPtr] = ptr ? ptr : LLVM::ZeroOp::create(builder, loc, pTp);
231 return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
236 static constexpr unsigned kNumStaticParams = 8;
237 static constexpr unsigned kNumDynamicParams = 2;
238 static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
239 static constexpr unsigned kParamDimSizes = 0;
240 static constexpr unsigned kParamLvlSizes = 1;
241 static constexpr unsigned kParamLvlTypes = 2;
242 static constexpr unsigned kParamDim2Lvl = 3;
243 static constexpr unsigned kParamLvl2Dim = 4;
244 static constexpr unsigned kParamPosTp = 5;
245 static constexpr unsigned kParamCrdTp = 6;
246 static constexpr unsigned kParamValTp = 7;
247 static constexpr unsigned kParamAction = 8;
248 static constexpr unsigned kParamPtr = 9;
253 Value params[kNumParams];
260 auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp);
270 auto resTp = MemRefType::get({ShapedType::kDynamic}, posTp);
282 auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
295 auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
309class SparseReturnConverter :
public OpConversionPattern<func::ReturnOp> {
311 using OpConversionPattern::OpConversionPattern;
313 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
314 ConversionPatternRewriter &rewriter)
const override {
315 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
321class SparseTensorLvlOpConverter :
public OpConversionPattern<LvlOp> {
323 using OpConversionPattern::OpConversionPattern;
325 matchAndRewrite(LvlOp op, OpAdaptor adaptor,
326 ConversionPatternRewriter &rewriter)
const override {
333 std::optional<int64_t> lvl = op.getConstantLvlIndex();
340 Value src = adaptor.getOperands()[0];
341 rewriter.replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
347class SparseCastConverter :
public OpConversionPattern<tensor::CastOp> {
349 using OpConversionPattern::OpConversionPattern;
351 matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
352 ConversionPatternRewriter &rewriter)
const override {
356 if (!encDst || encDst != encSrc)
358 rewriter.replaceOp(op, adaptor.getOperands());
363class SparseReMapConverter :
public OpConversionPattern<ReinterpretMapOp> {
365 using OpConversionPattern::OpConversionPattern;
367 matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
368 ConversionPatternRewriter &rewriter)
const override {
370 rewriter.replaceOp(op, adaptor.getSource());
376class SparseTensorNewConverter :
public OpConversionPattern<NewOp> {
378 using OpConversionPattern::OpConversionPattern;
380 matchAndRewrite(
NewOp op, OpAdaptor adaptor,
381 ConversionPatternRewriter &rewriter)
const override {
382 Location loc = op.getLoc();
387 SmallVector<Value> dimSizesValues;
388 Value dimSizesBuffer;
389 Value reader =
genReader(rewriter, loc, stt, adaptor.getOperands()[0],
390 dimSizesValues, dimSizesBuffer);
392 Value tensor = NewCallParams(rewriter, loc)
393 .genBuffers(stt, dimSizesValues, dimSizesBuffer)
394 .genNewCall(Action::kFromReader, reader);
396 createFuncCall(rewriter, loc,
"delSparseTensorReader", {}, {reader},
397 EmitCInterface::Off);
398 rewriter.replaceOp(op, tensor);
405class SparseTensorAllocConverter
406 :
public OpConversionPattern<bufferization::AllocTensorOp> {
408 using OpConversionPattern::OpConversionPattern;
410 matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
411 ConversionPatternRewriter &rewriter)
const override {
416 return rewriter.notifyMatchFailure(op,
"alloc copy not implemented");
418 Location loc = op.getLoc();
420 SmallVector<Value> dimSizesValues;
421 dimSizesValues.reserve(dimRank);
422 unsigned operandCtr = 0;
423 for (
Dimension d = 0; d < dimRank; d++) {
424 dimSizesValues.push_back(
426 ? adaptor.getOperands()[operandCtr++]
431 rewriter.replaceOp(op, NewCallParams(rewriter, loc)
433 .genNewCall(Action::kEmpty));
439class SparseTensorEmptyConverter :
public OpConversionPattern<tensor::EmptyOp> {
441 using OpConversionPattern::OpConversionPattern;
443 matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
444 ConversionPatternRewriter &rewriter)
const override {
445 Location loc = op.getLoc();
451 SmallVector<Value> dimSizesValues;
452 dimSizesValues.reserve(dimRank);
453 auto shape = op.getType().getShape();
454 unsigned operandCtr = 0;
455 for (
Dimension d = 0; d < dimRank; d++) {
457 ? adaptor.getOperands()[operandCtr++]
462 rewriter.replaceOp(op, NewCallParams(rewriter, loc)
464 .genNewCall(Action::kEmpty));
470class SparseTensorReorderCOOConverter
471 :
public OpConversionPattern<ReorderCOOOp> {
473 using OpConversionPattern::OpConversionPattern;
476 matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
477 ConversionPatternRewriter &rewriter)
const override {
478 const Location loc = op->getLoc();
482 const Value src = adaptor.getInputCoo();
484 NewCallParams params(rewriter, loc);
485 SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, srcTp, src);
486 rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizesValues)
487 .genNewCall(Action::kSortCOOInPlace, src));
494class SparseTensorDeallocConverter
495 :
public OpConversionPattern<bufferization::DeallocTensorOp> {
497 using OpConversionPattern::OpConversionPattern;
499 matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
500 ConversionPatternRewriter &rewriter)
const override {
503 StringRef name =
"delSparseTensor";
504 createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
505 EmitCInterface::Off);
506 rewriter.eraseOp(op);
512class SparseTensorToPositionsConverter
513 :
public OpConversionPattern<ToPositionsOp> {
515 using OpConversionPattern::OpConversionPattern;
517 matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
518 ConversionPatternRewriter &rewriter)
const override {
520 auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
521 adaptor.getTensor(), op.getLevel());
522 rewriter.replaceOp(op, poss);
528class SparseTensorToCoordinatesConverter
529 :
public OpConversionPattern<ToCoordinatesOp> {
531 using OpConversionPattern::OpConversionPattern;
533 matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
534 ConversionPatternRewriter &rewriter)
const override {
535 const Location loc = op.getLoc();
537 auto crds = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
541 if (op.getType() != crds.getType())
542 crds = memref::CastOp::create(rewriter, loc, op.getType(), crds);
543 rewriter.replaceOp(op, crds);
549class SparseToCoordinatesBufferConverter
550 :
public OpConversionPattern<ToCoordinatesBufferOp> {
552 using OpConversionPattern::OpConversionPattern;
554 matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
555 ConversionPatternRewriter &rewriter)
const override {
556 const Location loc = op.getLoc();
558 auto crds = genCoordinatesBufferCall(
562 if (op.getType() != crds.getType())
563 crds = memref::CastOp::create(rewriter, loc, op.getType(), crds);
564 rewriter.replaceOp(op, crds);
570class SparseTensorToValuesConverter :
public OpConversionPattern<ToValuesOp> {
572 using OpConversionPattern::OpConversionPattern;
574 matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
575 ConversionPatternRewriter &rewriter)
const override {
577 auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
578 rewriter.replaceOp(op, vals);
584class SparseNumberOfEntriesConverter
585 :
public OpConversionPattern<NumberOfEntriesOp> {
587 using OpConversionPattern::OpConversionPattern;
589 matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
590 ConversionPatternRewriter &rewriter)
const override {
593 auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
595 rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero);
601class SparseTensorLoadConverter :
public OpConversionPattern<LoadOp> {
603 using OpConversionPattern::OpConversionPattern;
605 matchAndRewrite(LoadOp op, OpAdaptor adaptor,
606 ConversionPatternRewriter &rewriter)
const override {
607 if (op.getHasInserts()) {
609 StringRef name =
"endLexInsert";
610 createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
611 EmitCInterface::Off);
613 rewriter.replaceOp(op, adaptor.getOperands());
619class SparseTensorInsertConverter
620 :
public OpConversionPattern<tensor::InsertOp> {
622 using OpConversionPattern::OpConversionPattern;
624 matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
625 ConversionPatternRewriter &rewriter)
const override {
629 Location loc = op->getLoc();
636 assert(stt.
isIdentity() &&
"Run reinterpret-map before conversion.");
639 Value lvlCoords, vref;
641 OpBuilder::InsertionGuard guard(rewriter);
642 Operation *loop = op;
647 if (llvm::isa<LoopLikeOpInterface>(loop)) {
649 rewriter.setInsertionPoint(loop);
651 lvlCoords =
genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
654 storeAll(rewriter, loc, lvlCoords, adaptor.getIndices());
655 memref::StoreOp::create(rewriter, loc, adaptor.getScalar(), vref);
658 {adaptor.getDest(), lvlCoords, vref}, EmitCInterface::On);
659 rewriter.replaceOp(op, adaptor.getDest());
665class SparseTensorExpandConverter :
public OpConversionPattern<ExpandOp> {
667 using OpConversionPattern::OpConversionPattern;
669 matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
670 ConversionPatternRewriter &rewriter)
const override {
671 Location loc = op->getLoc();
673 Type eltType = srcTp.getElementType();
674 Type boolType = rewriter.getIntegerType(1);
675 Type idxType = rewriter.getIndexType();
677 rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
679 Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
680 srcTp.getLvlRank() - 1);
684 Value values = genAlloc(rewriter, loc, sz, eltType);
685 Value filled = genAlloc(rewriter, loc, sz, boolType);
686 Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
693 linalg::FillOp::create(rewriter, loc,
696 linalg::FillOp::create(rewriter, loc,
700 assert(op.getNumResults() == 4);
701 rewriter.replaceOp(op, {values, filled, lastLvlCoordinates, zero});
707class SparseTensorCompressConverter :
public OpConversionPattern<CompressOp> {
709 using OpConversionPattern::OpConversionPattern;
711 matchAndRewrite(CompressOp op, OpAdaptor adaptor,
712 ConversionPatternRewriter &rewriter)
const override {
713 Location loc = op->getLoc();
718 Value values = adaptor.getValues();
719 Value filled = adaptor.getFilled();
720 Value added = adaptor.getAdded();
721 Value count = adaptor.getCount();
722 Value tensor = adaptor.getTensor();
726 auto lvlCoords =
genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
727 storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
730 {tensor, lvlCoords, values, filled, added, count},
732 Operation *parent =
getTop(op);
733 rewriter.setInsertionPointAfter(parent);
734 rewriter.replaceOp(op, adaptor.getTensor());
736 memref::DeallocOp::create(rewriter, loc, values);
737 memref::DeallocOp::create(rewriter, loc, filled);
738 memref::DeallocOp::create(rewriter, loc, added);
744class SparseTensorAssembleConverter :
public OpConversionPattern<AssembleOp> {
746 using OpConversionPattern::OpConversionPattern;
748 matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
749 ConversionPatternRewriter &rewriter)
const override {
750 const Location loc = op->getLoc();
752 assert(dstTp.hasStaticDimShape());
753 SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, dstTp);
759 NewCallParams(rewriter, loc)
760 .genBuffers(dstTp.withoutDimToLvl(), dimSizesValues)
761 .genNewCall(Action::kPack,
762 genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
763 adaptor.getValues()));
764 rewriter.replaceOp(op, dst);
776class SparseTensorDisassembleConverter
777 :
public OpConversionPattern<DisassembleOp> {
779 using OpConversionPattern::OpConversionPattern;
781 matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
782 ConversionPatternRewriter &rewriter)
const override {
783 Location loc = op->getLoc();
785 SmallVector<Value> retVal;
786 SmallVector<Value> retLen;
789 Level trailCOOLen = 0;
790 for (
Level l = 0; l < lvlRank; l++) {
797 trailCOOLen = lvlRank - l;
802 genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
804 auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
805 retVal.push_back(poss);
810 genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
812 auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
813 retVal.push_back(crds);
818 if (trailCOOLen != 0) {
819 uint64_t cooStartLvl = lvlRank - trailCOOLen;
824 auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
827 auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
828 retVal.push_back(poss);
833 auto buf =
genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
834 auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
836 auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
840 auto bufLen = arith::MulIOp::create(rewriter, loc, crdLen, two);
841 Type indexType = rewriter.getIndexType();
844 scf::ForOp forOp = scf::ForOp::create(rewriter, loc, zero, crdLen, one);
845 auto idx = forOp.getInductionVar();
846 rewriter.setInsertionPointToStart(forOp.getBody());
847 auto c0 = memref::LoadOp::create(rewriter, loc, crds0, idx);
848 auto c1 = memref::LoadOp::create(rewriter, loc, crds1, idx);
849 SmallVector<Value> args;
851 args.push_back(zero);
852 memref::StoreOp::create(rewriter, loc, c0, buf, args);
854 memref::StoreOp::create(rewriter, loc, c1, buf, args);
855 rewriter.setInsertionPointAfter(forOp);
856 auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
857 retVal.push_back(buf);
861 auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
862 auto valLenTp = op.getValLen().getType();
864 retVal.push_back(vals);
868 assert(retVal.size() + retLen.size() == op.getNumResults());
869 for (
unsigned i = 0, sz = retVal.size(); i < sz; i++) {
870 auto tensor = bufferization::ToTensorOp::create(
874 tensor::CastOp::create(rewriter, loc, op.getResultTypes()[i], tensor);
878 retVal.append(retLen.begin(), retLen.end());
879 rewriter.replaceOp(op, retVal);
884struct SparseHasRuntimeLibraryConverter
885 :
public OpConversionPattern<HasRuntimeLibraryOp> {
886 using OpConversionPattern::OpConversionPattern;
888 matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
889 ConversionPatternRewriter &rewriter)
const override {
890 auto i1Type = rewriter.getI1Type();
891 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
892 op, i1Type, rewriter.getIntegerAttr(i1Type, 1));
904 addConversion([](
Type type) {
return type; });
905 addConversion(convertSparseTensorTypes);
917 .add<SparseReturnConverter, SparseTensorLvlOpConverter,
918 SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
919 SparseTensorAllocConverter, SparseTensorEmptyConverter,
920 SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
921 SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
922 SparseToCoordinatesBufferConverter, SparseTensorToValuesConverter,
923 SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
924 SparseTensorInsertConverter, SparseTensorExpandConverter,
925 SparseTensorCompressConverter, SparseTensorAssembleConverter,
926 SparseTensorDisassembleConverter, SparseHasRuntimeLibraryConverter>(
927 typeConverter,
patterns.getContext());
static void genBuffers(CodegenEnv &env, OpBuilder &builder)
Local bufferization of all dense and sparse data structures.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
SparseTensorTypeToPtrConverter()
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A wrapper around RankedTensorType, which has three goals:
Size getDynamicDimSize(Dimension d) const
Safely looks up the requested dimension-DynSize.
bool isWithPos(Level l) const
Type getElementType() const
bool isLooseCompressedLvl(Level l) const
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
Dimension getDimRank() const
Returns the dimension-rank.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
bool isCompressedLvl(Level l) const
bool isWithCrd(Level l) const
Level getLvlRank() const
Returns the level-rank.
SparseTensorEncodingAttr getEncoding() const
bool isDynamicDim(Dimension d) const
Returns true if the given dimension has dynamic size.
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
AffineMap getDimToLvl() const
Returns the dimToLvl mapping (or the null-map for the identity).
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
bool isUniqueLvl(Level l) const
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp)
Generates an uninitialized temporary buffer with room for one value of the given type,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
Value constantLevelTypeEncoding(OpBuilder &builder, Location loc, LevelType lt)
Generates a constant of the internal dimension level type encoding.
Value allocaBuffer(OpBuilder &builder, Location loc, ValueRange values)
Generates a temporary buffer, initializes it with the given contents, and returns it as type memref<?
Value constantPosTypeEncoding(OpBuilder &builder, Location loc, SparseTensorEncodingAttr enc)
Generates a constant of the internal type-encoding for position overhead storage.
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
Action
The actions performed by @newSparseTensor.
TypedValue< BaseMemRefType > genToMemref(OpBuilder &builder, Location loc, Value tensor)
Value constantAction(OpBuilder &builder, Location loc, Action action)
Generates a constant of the given Action.
Value constantCrdTypeEncoding(OpBuilder &builder, Location loc, SparseTensorEncodingAttr enc)
Generates a constant of the internal type-encoding for coordinate overhead storage.
StringRef overheadTypeFunctionSuffix(OverheadType ot)
Convert OverheadType to its function-name suffix.
uint64_t Level
The type of level identifiers and level-ranks.
Operation * getTop(Operation *op)
Scans to top of generated loop.
Type getOpaquePointerType(MLIRContext *ctx)
Returns the equivalent of void* for opaque arguments to the execution engine.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Value genMapBuffers(OpBuilder &builder, Location loc, SparseTensorType stt, ArrayRef< Value > dimSizesValues, Value dimSizesBuffer, SmallVectorImpl< Value > &lvlSizesValues, Value &dim2lvlBuffer, Value &lvl2dimBuffer)
Generates code to set up the buffer parameters for a map.
Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt, Value tensor, SmallVectorImpl< Value > &dimSizesValues, Value &dimSizesBuffer)
Generates code that opens a reader and sets the dimension sizes.
Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem, Type dstTp)
Add conversion from scalar to given type (possibly a 0-rank tensor).
Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp)
Generates an uninitialized temporary buffer of the given size and type, but returns it as type memref...
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Creates a CallOp to the function reference returned by getFunc() in the builder's module.
StringRef primaryTypeFunctionSuffix(PrimaryType pt)
Convert PrimaryType to its function-name suffix.
Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc, Type elemTp)
Generates a constant of the internal type-encoding for primary storage.
void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs, size_t offsetIdx=0, Value offsetVal=Value())
Stores all the values of vs into the memref mem, which must have rank-1 and size greater-or-equal to ...
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void populateSparseTensorConversionPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Sets up sparse tensor conversion rules.
const FrozenRewritePatternSet & patterns