43 static std::optional<Type> convertSparseTensorTypes(
Type type) {
54 StringRef name =
"sparseLvlSize";
66 StringRef name =
"sparseDimSize";
89 if (!ShapedType::isDynamic(sz))
95 return genLvlSizeCall(builder, loc, tensor, lvl);
106 if (!ShapedType::isDynamic(sz))
109 return genDimSizeCall(builder, loc, tensor, dim);
118 out.reserve(dimRank);
120 out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
130 fillDimSizes(builder, loc, stt, tensor, out);
148 for (
const auto lt : stt.
getEncoding().getLvlTypes())
157 return builder.
create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
164 lvlBarePtrs.reserve(lvlTensors.size() + 1);
166 for (
const auto lvl : lvlTensors)
167 lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));
170 lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
171 Value idxPtr = builder.
create<memref::ExtractAlignedPointerAsIndexOp>(
183 class NewCallParams final {
195 assert(dimSizesValues.size() ==
static_cast<size_t>(stt.
getDimRank()));
197 params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
199 params[kParamDimSizes] = dimSizesBuffer
204 builder, loc, stt, dimSizesValues, params[kParamDimSizes],
205 lvlSizesValues, params[kParamDim2Lvl], params[kParamLvl2Dim]);
210 params[kParamValTp] =
217 bool isInitialized()
const {
218 for (
unsigned i = 0; i < kNumStaticParams; ++i)
227 assert(isInitialized() &&
"Must initialize before genNewCall");
228 StringRef name =
"newSparseTensor";
230 params[kParamPtr] = ptr ? ptr : builder.
create<LLVM::ZeroOp>(loc, pTp);
236 static constexpr
unsigned kNumStaticParams = 8;
237 static constexpr
unsigned kNumDynamicParams = 2;
238 static constexpr
unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
239 static constexpr
unsigned kParamDimSizes = 0;
240 static constexpr
unsigned kParamLvlSizes = 1;
241 static constexpr
unsigned kParamLvlTypes = 2;
242 static constexpr
unsigned kParamDim2Lvl = 3;
243 static constexpr
unsigned kParamLvl2Dim = 4;
244 static constexpr
unsigned kParamPosTp = 5;
245 static constexpr
unsigned kParamCrdTp = 6;
246 static constexpr
unsigned kParamValTp = 7;
247 static constexpr
unsigned kParamAction = 8;
248 static constexpr
unsigned kParamPtr = 9;
253 Value params[kNumParams];
313 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
325 matchAndRewrite(LvlOp op, OpAdaptor adaptor,
333 std::optional<int64_t> lvl = op.getConstantLvlIndex();
340 Value src = adaptor.getOperands()[0];
341 rewriter.
replaceOp(op, genLvlSizeCall(rewriter, op.getLoc(), src, *lvl));
351 matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
356 if (!encDst || encDst != encSrc)
358 rewriter.
replaceOp(op, adaptor.getOperands());
367 matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
370 rewriter.
replaceOp(op, adaptor.getSource());
380 matchAndRewrite(
NewOp op, OpAdaptor adaptor,
388 Value dimSizesBuffer;
389 Value reader =
genReader(rewriter, loc, stt, adaptor.getOperands()[0],
390 dimSizesValues, dimSizesBuffer);
392 Value tensor = NewCallParams(rewriter, loc)
393 .genBuffers(stt, dimSizesValues, dimSizesBuffer)
396 createFuncCall(rewriter, loc,
"delSparseTensorReader", {}, {reader},
405 class SparseTensorAllocConverter
410 matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
421 dimSizesValues.reserve(dimRank);
422 unsigned operandCtr = 0;
423 for (
Dimension d = 0; d < dimRank; d++) {
424 dimSizesValues.push_back(
426 ? adaptor.getOperands()[operandCtr++]
431 rewriter.
replaceOp(op, NewCallParams(rewriter, loc)
443 matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
452 dimSizesValues.reserve(dimRank);
453 auto shape = op.getType().getShape();
454 unsigned operandCtr = 0;
455 for (
Dimension d = 0; d < dimRank; d++) {
457 ? adaptor.getOperands()[operandCtr++]
462 rewriter.
replaceOp(op, NewCallParams(rewriter, loc)
470 class SparseTensorReorderCOOConverter
476 matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
482 const Value src = adaptor.getInputCoo();
484 NewCallParams params(rewriter, loc);
486 rewriter.
replaceOp(op, params.genBuffers(dstTp, dimSizesValues)
494 class SparseTensorDeallocConverter
499 matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
503 StringRef name =
"delSparseTensor";
504 createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
512 class SparseTensorToPositionsConverter
517 matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
520 auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
521 adaptor.getTensor(), op.getLevel());
528 class SparseTensorToCoordinatesConverter
533 matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
537 auto crds = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
541 if (op.getType() != crds.getType())
542 crds = rewriter.
create<memref::CastOp>(loc, op.getType(), crds);
549 class SparseToCoordinatesBufferConverter
554 matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
558 auto crds = genCoordinatesBufferCall(
562 if (op.getType() != crds.getType())
563 crds = rewriter.
create<memref::CastOp>(loc, op.getType(), crds);
574 matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
577 auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
584 class SparseNumberOfEntriesConverter
589 matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
593 auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
605 matchAndRewrite(LoadOp op, OpAdaptor adaptor,
607 if (op.getHasInserts()) {
609 StringRef name =
"endLexInsert";
610 createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(),
613 rewriter.
replaceOp(op, adaptor.getOperands());
619 class SparseTensorInsertConverter
624 matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
636 assert(stt.
isIdentity() &&
"Run reinterpret-map before conversion.");
639 Value lvlCoords, vref;
647 if (llvm::isa<LoopLikeOpInterface>(loop)) {
654 storeAll(rewriter, loc, lvlCoords, adaptor.getIndices());
655 rewriter.
create<memref::StoreOp>(loc, adaptor.getScalar(), vref);
659 rewriter.
replaceOp(op, adaptor.getDest());
669 matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
673 Type eltType = srcTp.getElementType();
679 Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
680 srcTp.getLvlRank() - 1);
684 Value values = genAlloc(rewriter, loc, sz, eltType);
685 Value filled = genAlloc(rewriter, loc, sz, boolType);
686 Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
693 rewriter.
create<linalg::FillOp>(
696 rewriter.
create<linalg::FillOp>(
700 assert(op.getNumResults() == 4);
701 rewriter.
replaceOp(op, {values, filled, lastLvlCoordinates, zero});
711 matchAndRewrite(CompressOp op, OpAdaptor adaptor,
718 Value values = adaptor.getValues();
719 Value filled = adaptor.getFilled();
720 Value added = adaptor.getAdded();
721 Value count = adaptor.getCount();
722 Value tensor = adaptor.getTensor();
727 storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
730 {tensor, lvlCoords, values, filled, added, count},
732 rewriter.
replaceOp(op, adaptor.getTensor());
736 rewriter.
create<memref::DeallocOp>(loc, values);
737 rewriter.
create<memref::DeallocOp>(loc, filled);
738 rewriter.
create<memref::DeallocOp>(loc, added);
748 matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
752 assert(dstTp.hasStaticDimShape());
759 NewCallParams(rewriter, loc)
760 .genBuffers(dstTp.withoutDimToLvl(), dimSizesValues)
762 genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
763 adaptor.getValues()));
776 class SparseTensorDisassembleConverter
781 matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
789 Level trailCOOLen = 0;
790 for (
Level l = 0; l < lvlRank; l++) {
797 trailCOOLen = lvlRank - l;
802 genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
804 auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
805 retVal.push_back(poss);
810 genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
812 auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
813 retVal.push_back(crds);
818 if (trailCOOLen != 0) {
819 uint64_t cooStartLvl = lvlRank - trailCOOLen;
824 auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
827 auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
828 retVal.push_back(poss);
833 auto buf =
genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
834 auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
836 auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
840 auto bufLen = rewriter.
create<arith::MulIOp>(loc, crdLen, two);
844 scf::ForOp forOp = rewriter.
create<scf::ForOp>(loc, zero, crdLen, one);
845 auto idx = forOp.getInductionVar();
847 auto c0 = rewriter.
create<memref::LoadOp>(loc, crds0, idx);
848 auto c1 = rewriter.
create<memref::LoadOp>(loc, crds1, idx);
851 args.push_back(zero);
852 rewriter.
create<memref::StoreOp>(loc, c0, buf, args);
854 rewriter.
create<memref::StoreOp>(loc, c1, buf, args);
856 auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
857 retVal.push_back(buf);
861 auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
862 auto valLenTp = op.getValLen().getType();
864 retVal.push_back(vals);
868 assert(retVal.size() + retLen.size() == op.getNumResults());
869 for (
unsigned i = 0, sz = retVal.size(); i < sz; i++) {
870 auto tensor = rewriter.
create<bufferization::ToTensorOp>(loc, retVal[i]);
872 rewriter.
create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
876 retVal.append(retLen.begin(), retLen.end());
882 struct SparseHasRuntimeLibraryConverter
886 matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor,
902 addConversion([](
Type type) {
return type; });
903 addConversion(convertSparseTensorTypes);
915 .add<SparseReturnConverter, SparseTensorLvlOpConverter,
916 SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
917 SparseTensorAllocConverter, SparseTensorEmptyConverter,
918 SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
919 SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
920 SparseToCoordinatesBufferConverter, SparseTensorToValuesConverter,
921 SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
922 SparseTensorInsertConverter, SparseTensorExpandConverter,
923 SparseTensorCompressConverter, SparseTensorAssembleConverter,
924 SparseTensorDisassembleConverter, SparseHasRuntimeLibraryConverter>(
925 typeConverter,
patterns.getContext());
static void genBuffers(CodegenEnv &env, OpBuilder &builder)
Local bufferization of all dense and sparse data structures.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
SparseTensorTypeToPtrConverter()
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A wrapper around RankedTensorType, which has three goals:
Size getDynamicDimSize(Dimension d) const
Safely looks up the requested dimension-DynSize.
bool isWithPos(Level l) const
Type getElementType() const
bool isLooseCompressedLvl(Level l) const
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
Dimension getDimRank() const
Returns the dimension-rank.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
bool isCompressedLvl(Level l) const
bool isWithCrd(Level l) const
Level getLvlRank() const
Returns the level-rank.
SparseTensorEncodingAttr getEncoding() const
bool isDynamicDim(Dimension d) const
Returns true if the given dimension has dynamic size.
Level getAoSCOOStart() const
Returns the starting level of this sparse tensor type for a trailing COO region that spans at least t...
AffineMap getDimToLvl() const
Returns the dimToLvl mapping (or the null-map for the identity).
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
bool isUniqueLvl(Level l) const
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp)
Generates an uninitialized temporary buffer with room for one value of the given type,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Value constantLevelTypeEncoding(OpBuilder &builder, Location loc, LevelType lt)
Generates a constant of the internal dimension level type encoding.
Value allocaBuffer(OpBuilder &builder, Location loc, ValueRange values)
Generates a temporary buffer, initializes it with the given contents, and returns it as type memref<?...
Value constantPosTypeEncoding(OpBuilder &builder, Location loc, SparseTensorEncodingAttr enc)
Generates a constant of the internal type-encoding for position overhead storage.
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
Action
The actions performed by @newSparseTensor.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
uint64_t Level
The type of level identifiers and level-ranks.
TypedValue< BaseMemRefType > genToMemref(OpBuilder &builder, Location loc, Value tensor)
Value constantAction(OpBuilder &builder, Location loc, Action action)
Generates a constant of the given Action.
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
Value constantCrdTypeEncoding(OpBuilder &builder, Location loc, SparseTensorEncodingAttr enc)
Generates a constant of the internal type-encoding for coordinate overhead storage.
StringRef overheadTypeFunctionSuffix(OverheadType ot)
Convert OverheadType to its function-name suffix.
Operation * getTop(Operation *op)
Scans to top of generated loop.
Type getOpaquePointerType(MLIRContext *ctx)
Returns the equivalent of void* for opaque arguments to the execution engine.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Value genMapBuffers(OpBuilder &builder, Location loc, SparseTensorType stt, ArrayRef< Value > dimSizesValues, Value dimSizesBuffer, SmallVectorImpl< Value > &lvlSizesValues, Value &dim2lvlBuffer, Value &lvl2dimBuffer)
Generates code to set up the buffer parameters for a map.
Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt, Value tensor, SmallVectorImpl< Value > &dimSizesValues, Value &dimSizesBuffer)
Generates code that opens a reader and sets the dimension sizes.
Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem, Type dstTp)
Add conversion from scalar to given type (possibly a 0-rank tensor).
Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp)
Generates an uninitialized temporary buffer of the given size and type, but returns it as type memref...
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Creates a CallOp to the function reference returned by getFunc() in the builder's module.
StringRef primaryTypeFunctionSuffix(PrimaryType pt)
Convert PrimaryType to its function-name suffix.
Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc, Type elemTp)
Generates a constant of the internal type-encoding for primary storage.
void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs, size_t offsetIdx=0, Value offsetVal=Value())
Stores all the values of vs into the memref mem, which must have rank-1 and size greater-or-equal to ...
Include the generated interface declarations.
void populateSparseTensorConversionPatterns(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Sets up sparse tensor conversion rules.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...