43 static std::optional<Type> convertSparseTensorTypes(
Type type) {
54 StringRef name =
"sparseLvlSize";
66 StringRef name =
"sparseDimSize";
89 if (!ShapedType::isDynamic(sz))
95 return genLvlSizeCall(builder, loc, tensor, lvl);
106 if (!ShapedType::isDynamic(sz))
109 return genDimSizeCall(builder, loc, tensor, dim);
118 out.reserve(dimRank);
120 out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d));
130 fillDimSizes(builder, loc, stt, tensor, out);
148 for (
const auto lt : stt.
getEncoding().getLvlTypes())
157 return builder.
create<memref::ExtractAlignedPointerAsIndexOp>(loc, buf);
164 lvlBarePtrs.reserve(lvlTensors.size() + 1);
166 for (
const auto lvl : lvlTensors)
167 lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, lvl));
170 lvlBarePtrs.push_back(extractBarePtrFromTensor(builder, loc, valTensor));
171 Value idxPtr = builder.
create<memref::ExtractAlignedPointerAsIndexOp>(
183 class NewCallParams final {
195 assert(dimSizesValues.size() ==
static_cast<size_t>(stt.
getDimRank()));
197 params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
199 params[kParamDimSizes] = dimSizesBuffer
204 builder, loc, stt, dimSizesValues, params[kParamDimSizes],
205 lvlSizesValues, params[kParamDim2Lvl], params[kParamLvl2Dim]);
210 params[kParamValTp] =
217 bool isInitialized()
const {
218 for (
unsigned i = 0; i < kNumStaticParams; ++i)
227 assert(isInitialized() &&
"Must initialize before genNewCall");
228 StringRef name =
"newSparseTensor";
230 params[kParamPtr] = ptr ? ptr : builder.
create<LLVM::ZeroOp>(loc, pTp);
236 static constexpr
unsigned kNumStaticParams = 8;
237 static constexpr
unsigned kNumDynamicParams = 2;
238 static constexpr
unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
239 static constexpr
unsigned kParamDimSizes = 0;
240 static constexpr
unsigned kParamLvlSizes = 1;
241 static constexpr
unsigned kParamLvlTypes = 2;
242 static constexpr
unsigned kParamDim2Lvl = 3;
243 static constexpr
unsigned kParamLvl2Dim = 4;
244 static constexpr
unsigned kParamPosTp = 5;
245 static constexpr
unsigned kParamCrdTp = 6;
246 static constexpr
unsigned kParamValTp = 7;
247 static constexpr
unsigned kParamAction = 8;
248 static constexpr
unsigned kParamPtr = 9;
253 Value params[kNumParams];
299 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
311 matchAndRewrite(LvlOp op, OpAdaptor adaptor,
319 std::optional<int64_t> lvl = op.getConstantLvlIndex();
326 Value src = adaptor.getOperands()[0];
327 rewriter.
replaceOp(op, genLvlSizeCall(rewriter, op.
getLoc(), src, *lvl));
337 matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
342 if (!encDst || encDst != encSrc)
344 rewriter.
replaceOp(op, adaptor.getOperands());
353 matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
356 rewriter.
replaceOp(op, adaptor.getSource());
366 matchAndRewrite(
NewOp op, OpAdaptor adaptor,
374 Value dimSizesBuffer;
375 Value reader =
genReader(rewriter, loc, stt, adaptor.getOperands()[0],
376 dimSizesValues, dimSizesBuffer);
378 Value tensor = NewCallParams(rewriter, loc)
379 .genBuffers(stt, dimSizesValues, dimSizesBuffer)
382 createFuncCall(rewriter, loc,
"delSparseTensorReader", {}, {reader},
391 class SparseTensorAllocConverter
396 matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
407 dimSizesValues.reserve(dimRank);
408 unsigned operandCtr = 0;
409 for (
Dimension d = 0; d < dimRank; d++) {
410 dimSizesValues.push_back(
412 ? adaptor.getOperands()[operandCtr++]
417 rewriter.
replaceOp(op, NewCallParams(rewriter, loc)
429 matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
438 dimSizesValues.reserve(dimRank);
439 auto shape = op.getType().getShape();
440 unsigned operandCtr = 0;
441 for (
Dimension d = 0; d < dimRank; d++) {
443 ? adaptor.getOperands()[operandCtr++]
448 rewriter.
replaceOp(op, NewCallParams(rewriter, loc)
456 class SparseTensorReorderCOOConverter
462 matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
468 const Value src = adaptor.getInputCoo();
470 NewCallParams params(rewriter, loc);
472 rewriter.
replaceOp(op, params.genBuffers(dstTp, dimSizesValues)
480 class SparseTensorDeallocConverter
485 matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
489 StringRef name =
"delSparseTensor";
498 class SparseTensorToPositionsConverter
503 matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
506 auto poss = genPositionsCall(rewriter, op.
getLoc(), stt,
507 adaptor.getTensor(), op.getLevel());
514 class SparseTensorToCoordinatesConverter
519 matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
522 auto crds = genCoordinatesCall(rewriter, op.
getLoc(), stt,
523 adaptor.getTensor(), op.getLevel());
526 if (op.getType() != crds.getType())
527 crds = rewriter.
create<memref::CastOp>(op.
getLoc(), op.getType(), crds);
538 matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
541 auto vals = genValuesCall(rewriter, op.
getLoc(), stt, adaptor.getTensor());
548 class SparseNumberOfEntriesConverter
553 matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
557 auto vals = genValuesCall(rewriter, op.
getLoc(), stt, adaptor.getTensor());
569 matchAndRewrite(LoadOp op, OpAdaptor adaptor,
571 if (op.getHasInserts()) {
573 StringRef name =
"endLexInsert";
577 rewriter.
replaceOp(op, adaptor.getOperands());
587 matchAndRewrite(InsertOp op, OpAdaptor adaptor,
596 Value lvlCoords, vref;
604 if (llvm::isa<LoopLikeOpInterface>(loop)) {
611 storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
612 rewriter.
create<memref::StoreOp>(loc, adaptor.getValue(), vref);
616 rewriter.
replaceOp(op, adaptor.getTensor());
626 matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
630 Type eltType = srcTp.getElementType();
636 Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(),
637 srcTp.getLvlRank() - 1);
641 Value values = genAlloc(rewriter, loc, sz, eltType);
642 Value filled = genAlloc(rewriter, loc, sz, boolType);
643 Value lastLvlCoordinates = genAlloc(rewriter, loc, sz, idxType);
650 rewriter.
create<linalg::FillOp>(
653 rewriter.
create<linalg::FillOp>(
658 rewriter.
replaceOp(op, {values, filled, lastLvlCoordinates, zero});
668 matchAndRewrite(CompressOp op, OpAdaptor adaptor,
675 Value values = adaptor.getValues();
676 Value filled = adaptor.getFilled();
677 Value added = adaptor.getAdded();
678 Value count = adaptor.getCount();
679 Value tensor = adaptor.getTensor();
684 storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
687 {tensor, lvlCoords, values, filled, added, count},
689 rewriter.
replaceOp(op, adaptor.getTensor());
693 rewriter.
create<memref::DeallocOp>(loc, values);
694 rewriter.
create<memref::DeallocOp>(loc, filled);
695 rewriter.
create<memref::DeallocOp>(loc, added);
705 matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
709 assert(dstTp.hasStaticDimShape());
716 NewCallParams(rewriter, loc)
717 .genBuffers(dstTp.withoutDimToLvl(), dimSizesValues)
719 genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
720 adaptor.getValues()));
727 class SparseTensorDisassembleConverter
732 matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
742 auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
743 auto valLenTp = op.getValLen().getType();
745 retVal.push_back(vals);
749 Level trailCOOLen = 0;
750 for (
Level l = 0; l < lvlRank; l++) {
757 trailCOOLen = lvlRank - l;
762 genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
764 auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
765 retVal.push_back(poss);
770 genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
772 auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
773 retVal.push_back(crds);
778 if (trailCOOLen != 0) {
779 uint64_t cooStartLvl = lvlRank - trailCOOLen;
784 auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
787 auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
788 retVal.push_back(poss);
794 genToMemref(rewriter, loc, op.getOutLevels()[retLen.size() - 1]);
795 auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
797 auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
801 auto bufLen = rewriter.
create<arith::MulIOp>(loc, crdLen, two);
805 scf::ForOp forOp = rewriter.
create<scf::ForOp>(loc, zero, crdLen, one);
806 auto idx = forOp.getInductionVar();
808 auto c0 = rewriter.
create<memref::LoadOp>(loc, crds0, idx);
809 auto c1 = rewriter.
create<memref::LoadOp>(loc, crds1, idx);
812 args.push_back(zero);
813 rewriter.
create<memref::StoreOp>(loc, c0, buf, args);
815 rewriter.
create<memref::StoreOp>(loc, c1, buf, args);
817 auto bufLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
818 retVal.push_back(buf);
823 for (
unsigned i = 0, sz = retVal.size(); i < sz; i++) {
824 auto tensor = rewriter.
create<bufferization::ToTensorOp>(loc, retVal[i]);
829 retVal.append(retLen.begin(), retLen.end());
842 addConversion([](
Type type) {
return type; });
843 addConversion(convertSparseTensorTypes);
855 .
add<SparseReturnConverter, SparseTensorLvlOpConverter,
856 SparseCastConverter, SparseReMapConverter, SparseTensorNewConverter,
857 SparseTensorAllocConverter, SparseTensorEmptyConverter,
858 SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
859 SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
860 SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
861 SparseTensorLoadConverter, SparseTensorInsertConverter,
862 SparseTensorExpandConverter, SparseTensorCompressConverter,
863 SparseTensorAssembleConverter, SparseTensorDisassembleConverter>(
static void genBuffers(CodegenEnv &env, OpBuilder &builder)
Local bufferization of all dense and sparse data structures.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
IntegerType getIntegerType(unsigned width)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
SparseTensorTypeToPtrConverter()
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A wrapper around RankedTensorType, which has three goals:
Size getDynamicDimSize(Dimension d) const
Safely looks up the requested dimension-DynSize.
bool isWithPos(Level l) const
Type getElementType() const
bool isLooseCompressedLvl(Level l) const
bool hasEncoding() const
Returns true for tensors which have an encoding, and false for those which do not.
Dimension getDimRank() const
Returns the dimension-rank.
Type getCrdType() const
Returns the coordinate-overhead MLIR type, defaulting to IndexType.
bool isIdentity() const
Returns true if the dimToLvl mapping is the identity.
bool isCompressedLvl(Level l) const
bool isWithCrd(Level l) const
Level getLvlRank() const
Returns the level-rank.
SparseTensorEncodingAttr getEncoding() const
bool isDynamicDim(Dimension d) const
Returns true if the given dimension has dynamic size.
AffineMap getDimToLvl() const
Returns the dimToLvl mapping (or the null-map for the identity).
Type getPosType() const
Returns the position-overhead MLIR type, defaulting to IndexType.
bool isUniqueLvl(Level l) const
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp)
Generates an uninitialized temporary buffer with room for one value of the given type,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Value constantLevelTypeEncoding(OpBuilder &builder, Location loc, LevelType lt)
Generates a constant of the internal dimension level type encoding.
Value allocaBuffer(OpBuilder &builder, Location loc, ValueRange values)
Generates a temporary buffer, initializes it with the given contents, and returns it as type memref<?...
Value constantPosTypeEncoding(OpBuilder &builder, Location loc, SparseTensorEncodingAttr enc)
Generates a constant of the internal type-encoding for position overhead storage.
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
Action
The actions performed by @newSparseTensor.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
uint64_t Level
The type of level identifiers and level-ranks.
TypedValue< BaseMemRefType > genToMemref(OpBuilder &builder, Location loc, Value tensor)
Value constantAction(OpBuilder &builder, Location loc, Action action)
Generates a constant of the given Action.
int64_t Size
The type for individual components of a compile-time shape, including the value ShapedType::kDynamic ...
Value constantCrdTypeEncoding(OpBuilder &builder, Location loc, SparseTensorEncodingAttr enc)
Generates a constant of the internal type-encoding for coordinate overhead storage.
StringRef overheadTypeFunctionSuffix(OverheadType ot)
Convert OverheadType to its function-name suffix.
Operation * getTop(Operation *op)
Scans to top of generated loop.
Type getOpaquePointerType(MLIRContext *ctx)
Returns the equivalent of void* for opaque arguments to the execution engine.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
Value genMapBuffers(OpBuilder &builder, Location loc, SparseTensorType stt, ArrayRef< Value > dimSizesValues, Value dimSizesBuffer, SmallVectorImpl< Value > &lvlSizesValues, Value &dim2lvlBuffer, Value &lvl2dimBuffer)
Generates code to set up the buffer parameters for a map.
Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt, Value tensor, SmallVectorImpl< Value > &dimSizesValues, Value &dimSizesBuffer)
Generates code that opens a reader and sets the dimension sizes.
Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem, Type dstTp)
Add conversion from scalar to given type (possibly a 0-rank tensor).
Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp)
Generates an uninitialized temporary buffer of the given size and type, but returns it as type memref...
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface)
Creates a CallOp to the function reference returned by getFunc() in the builder's module.
StringRef primaryTypeFunctionSuffix(PrimaryType pt)
Convert PrimaryType to its function-name suffix.
Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc, Type elemTp)
Generates a constant of the internal type-encoding for primary storage.
void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs, size_t offsetIdx=0, Value offsetVal=Value())
Stores all the values of vs into the memref mem, which must have rank-1 and size greater-or-equal to ...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateSparseTensorConversionPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns)
Sets up sparse tensor conversion rules.
This class represents an efficient way to signal success or failure.