18 using namespace sparse_tensor;
28 auto enc = tp.getEncoding();
29 const Level lvlRank = enc.getLvlRank();
38 result.push_back(lvlSizes);
39 result.push_back(memSizes);
46 result.push_back(dimOffset);
47 result.push_back(dimStride);
53 static Type convertSpecifier(StorageSpecifierType tp) {
54 return LLVM::LLVMStructType::getLiteral(tp.getContext(),
55 getSpecifierFields(tp));
62 constexpr uint64_t kLvlSizePosInSpecifier = 0;
63 constexpr uint64_t kMemSizePosInSpecifier = 1;
64 constexpr uint64_t kDimOffsetPosInSpecifier = 2;
65 constexpr uint64_t kDimStridePosInSpecifier = 3;
72 builder.
create<LLVM::ExtractValueOp>(loc, value, indices),
78 value = builder.
create<LLVM::InsertValueOp>(
113 Value metaData = builder.
create<LLVM::UndefOp>(loc, structType);
114 SpecifierStructBuilder md(metaData);
116 auto memSizeArrayType =
117 cast<LLVM::LLVMArrayType>(cast<LLVM::LLVMStructType>(structType)
118 .getBody()[kMemSizePosInSpecifier]);
122 for (
int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
123 md.setMemSize(builder, loc, i, zero);
126 SpecifierStructBuilder sourceMd(source);
127 md.setMemSizeArray(builder, loc, sourceMd.memSizeArray(builder, loc));
207 return builder.
create<LLVM::ExtractValueOp>(loc, value,
208 kMemSizePosInSpecifier);
214 value = builder.
create<LLVM::InsertValueOp>(loc, value, array,
215 kMemSizePosInSpecifier);
225 addConversion([](
Type type) {
return type; });
226 addConversion(convertSpecifier);
233 template <
typename Base,
typename SourceOp>
242 SpecifierStructBuilder spec(adaptor.getSpecifier());
243 switch (op.getSpecifierKind()) {
244 case StorageSpecifierKind::LvlSize: {
245 Value v = Base::onLvlSize(rewriter, op, spec, (*op.getLevel()));
249 case StorageSpecifierKind::DimOffset: {
250 Value v = Base::onDimOffset(rewriter, op, spec, (*op.getLevel()));
254 case StorageSpecifierKind::DimStride: {
255 Value v = Base::onDimStride(rewriter, op, spec, (*op.getLevel()));
259 case StorageSpecifierKind::CrdMemSize:
260 case StorageSpecifierKind::PosMemSize:
261 case StorageSpecifierKind::ValMemSize: {
262 auto enc = op.getSpecifier().getType().getEncoding();
264 std::optional<unsigned> lvl;
266 lvl = (*op.getLevel());
269 Value v = Base::onMemSize(rewriter, op, spec, idx);
274 llvm_unreachable(
"unrecognized specifer kind");
280 SetStorageSpecifierOp> {
281 using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
284 SpecifierStructBuilder &spec,
Level lvl) {
285 spec.setLvlSize(builder, op.getLoc(), lvl, op.getValue());
290 SpecifierStructBuilder &spec,
Dimension d) {
291 spec.setDimOffset(builder, op.getLoc(), d, op.getValue());
296 SpecifierStructBuilder &spec,
Dimension d) {
297 spec.setDimStride(builder, op.getLoc(), d, op.getValue());
302 SpecifierStructBuilder &spec,
FieldIndex fidx) {
303 spec.setMemSize(builder, op.getLoc(), fidx, op.getValue());
310 GetStorageSpecifierOp> {
311 using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
314 SpecifierStructBuilder &spec,
Level lvl) {
315 return spec.lvlSize(builder, op.getLoc(), lvl);
319 const SpecifierStructBuilder &spec,
Dimension d) {
320 return spec.dimOffset(builder, op.getLoc(), d);
324 const SpecifierStructBuilder &spec,
Dimension d) {
325 return spec.dimStride(builder, op.getLoc(), d);
329 SpecifierStructBuilder &spec,
FieldIndex fidx) {
330 return spec.memSize(builder, op.getLoc(), fidx);
341 Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
343 op, SpecifierStructBuilder::getInitValue(
344 rewriter, op.getLoc(), llvmType, adaptor.getSource()));
LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
IntegerType getIntegerType(unsigned width)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
StorageSpecifierToLLVMTypeConverter()
Helper class to produce LLVM dialect operations extracting or inserting values to a struct.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Provides methods to access fields of a sparse tensor with the given encoding.
FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Gets the field index for required field.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
unsigned FieldIndex
The type of field indices.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
uint64_t Level
The type of level identifiers and level-ranks.
SparseTensorFieldKind toFieldKind(StorageSpecifierKind kind)
unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc)
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy)
Add type casting between arith and index types when needed.
Include the generated interface declarations.
void populateStorageSpecifierToLLVMPatterns(const TypeConverter &converter, RewritePatternSet &patterns)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
static Value onLvlSize(OpBuilder &builder, GetStorageSpecifierOp op, SpecifierStructBuilder &spec, Level lvl)
static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op, const SpecifierStructBuilder &spec, Dimension d)
static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op, SpecifierStructBuilder &spec, FieldIndex fidx)
static Value onDimOffset(OpBuilder &builder, GetStorageSpecifierOp op, const SpecifierStructBuilder &spec, Dimension d)
LogicalResult matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static Value onLvlSize(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, Level lvl)
static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, Dimension d)
static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, Dimension d)
static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, FieldIndex fidx)