28 auto enc = tp.getEncoding();
29 const Level lvlRank = enc.getLvlRank();
34 auto sizeType = IntegerType::get(tp.getContext(), 64);
35 auto lvlSizes = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
36 auto memSizes = LLVM::LLVMArrayType::get(ctx, sizeType,
38 result.push_back(lvlSizes);
39 result.push_back(memSizes);
43 auto dimOffset = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
44 auto dimStride = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
46 result.push_back(dimOffset);
47 result.push_back(dimStride);
53static Type convertSpecifier(StorageSpecifierType tp) {
54 return LLVM::LLVMStructType::getLiteral(tp.getContext(),
55 getSpecifierFields(tp));
// Member positions inside the LLVM struct that a storage specifier lowers to.
// They mirror the push_back order in getSpecifierFields(): level sizes,
// memory sizes, dimension offsets, dimension strides.
constexpr uint64_t kLvlSizePosInSpecifier{0};   // array of per-level sizes
constexpr uint64_t kMemSizePosInSpecifier{1};   // array of memory sizes
constexpr uint64_t kDimOffsetPosInSpecifier{2}; // array of dimension offsets
constexpr uint64_t kDimStridePosInSpecifier{3}; // array of dimension strides
69 Value extractField(OpBuilder &builder, Location loc,
70 ArrayRef<int64_t>
indices)
const {
72 LLVM::ExtractValueOp::create(builder, loc, value,
indices),
76 void insertField(OpBuilder &builder, Location loc, ArrayRef<int64_t>
indices,
78 value = LLVM::InsertValueOp::create(
84 explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) {
89 static Value getInitValue(OpBuilder &builder, Location loc, Type structType,
92 Value lvlSize(OpBuilder &builder, Location loc,
Level lvl)
const;
93 void setLvlSize(OpBuilder &builder, Location loc,
Level lvl, Value size);
95 Value dimOffset(OpBuilder &builder, Location loc,
Dimension dim)
const;
96 void setDimOffset(OpBuilder &builder, Location loc,
Dimension dim,
99 Value dimStride(OpBuilder &builder, Location loc,
Dimension dim)
const;
100 void setDimStride(OpBuilder &builder, Location loc,
Dimension dim,
103 Value memSize(OpBuilder &builder, Location loc,
FieldIndex fidx)
const;
104 void setMemSize(OpBuilder &builder, Location loc,
FieldIndex fidx,
107 Value memSizeArray(OpBuilder &builder, Location loc)
const;
108 void setMemSizeArray(OpBuilder &builder, Location loc, Value array);
113 Value metaData = LLVM::PoisonOp::create(builder, loc, structType);
114 SpecifierStructBuilder md(metaData);
116 auto memSizeArrayType =
117 cast<LLVM::LLVMArrayType>(cast<LLVM::LLVMStructType>(structType)
118 .getBody()[kMemSizePosInSpecifier]);
120 Value zero =
constantZero(builder, loc, memSizeArrayType.getElementType());
122 for (
int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
123 md.setMemSize(builder, loc, i, zero);
126 SpecifierStructBuilder sourceMd(source);
127 md.setMemSizeArray(builder, loc, sourceMd.memSizeArray(builder, loc));
133Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc,
137 ArrayRef<int64_t>{kDimOffsetPosInSpecifier,
static_cast<int64_t
>(dim)});
141void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc,
145 ArrayRef<int64_t>{kDimOffsetPosInSpecifier,
static_cast<int64_t
>(dim)},
150Value SpecifierStructBuilder::lvlSize(OpBuilder &builder, Location loc,
156 ArrayRef<int64_t>{kLvlSizePosInSpecifier,
static_cast<int64_t
>(lvl)});
160void SpecifierStructBuilder::setLvlSize(OpBuilder &builder, Location loc,
161 Level lvl, Value size) {
166 ArrayRef<int64_t>{kLvlSizePosInSpecifier,
static_cast<int64_t
>(lvl)},
171Value SpecifierStructBuilder::dimStride(OpBuilder &builder, Location loc,
175 ArrayRef<int64_t>{kDimStridePosInSpecifier,
static_cast<int64_t
>(dim)});
179void SpecifierStructBuilder::setDimStride(OpBuilder &builder, Location loc,
183 ArrayRef<int64_t>{kDimStridePosInSpecifier,
static_cast<int64_t
>(dim)},
188Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
192 ArrayRef<int64_t>{kMemSizePosInSpecifier,
static_cast<int64_t
>(fidx)});
196void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
200 ArrayRef<int64_t>{kMemSizePosInSpecifier,
static_cast<int64_t
>(fidx)},
205Value SpecifierStructBuilder::memSizeArray(OpBuilder &builder,
206 Location loc)
const {
207 return LLVM::ExtractValueOp::create(builder, loc, value,
208 kMemSizePosInSpecifier);
212void SpecifierStructBuilder::setMemSizeArray(OpBuilder &builder, Location loc,
214 value = LLVM::InsertValueOp::create(builder, loc, value, array,
215 kMemSizePosInSpecifier);
225 addConversion([](
Type type) {
return type; });
226 addConversion(convertSpecifier);
233template <
typename Base,
typename SourceOp>
237 using OpConversionPattern<SourceOp>::OpConversionPattern;
241 ConversionPatternRewriter &rewriter)
const override {
242 SpecifierStructBuilder spec(adaptor.getSpecifier());
243 switch (op.getSpecifierKind()) {
244 case StorageSpecifierKind::LvlSize: {
245 Value v = Base::onLvlSize(rewriter, op, spec, (*op.getLevel()));
246 rewriter.replaceOp(op, v);
249 case StorageSpecifierKind::DimOffset: {
250 Value v = Base::onDimOffset(rewriter, op, spec, (*op.getLevel()));
251 rewriter.replaceOp(op, v);
254 case StorageSpecifierKind::DimStride: {
255 Value v = Base::onDimStride(rewriter, op, spec, (*op.getLevel()));
256 rewriter.replaceOp(op, v);
259 case StorageSpecifierKind::CrdMemSize:
260 case StorageSpecifierKind::PosMemSize:
261 case StorageSpecifierKind::ValMemSize: {
262 auto enc = op.getSpecifier().getType().getEncoding();
264 std::optional<unsigned> lvl;
266 lvl = (*op.getLevel());
269 Value v = Base::onMemSize(rewriter, op, spec, idx);
270 rewriter.replaceOp(op, v);
274 llvm_unreachable(
"unrecognized specifer kind");
280 SetStorageSpecifierOp> {
281 using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
284 SpecifierStructBuilder &spec,
Level lvl) {
285 spec.setLvlSize(builder, op.getLoc(), lvl, op.getValue());
290 SpecifierStructBuilder &spec,
Dimension d) {
291 spec.setDimOffset(builder, op.getLoc(), d, op.getValue());
296 SpecifierStructBuilder &spec,
Dimension d) {
297 spec.setDimStride(builder, op.getLoc(), d, op.getValue());
302 SpecifierStructBuilder &spec,
FieldIndex fidx) {
303 spec.setMemSize(builder, op.getLoc(), fidx, op.getValue());
310 GetStorageSpecifierOp> {
311 using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
314 SpecifierStructBuilder &spec,
Level lvl) {
315 return spec.lvlSize(builder, op.getLoc(), lvl);
319 const SpecifierStructBuilder &spec,
Dimension d) {
320 return spec.dimOffset(builder, op.getLoc(), d);
324 const SpecifierStructBuilder &spec,
Dimension d) {
325 return spec.dimStride(builder, op.getLoc(), d);
329 SpecifierStructBuilder &spec,
FieldIndex fidx) {
330 return spec.memSize(builder, op.getLoc(), fidx);
335 :
public OpConversionPattern<StorageSpecifierInitOp> {
337 using OpConversionPattern::OpConversionPattern;
340 ConversionPatternRewriter &rewriter)
const override {
341 Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
343 op, SpecifierStructBuilder::getInitValue(
344 rewriter, op.getLoc(), llvmType, adaptor.getSource()));
LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SourceOp::Adaptor OpAdaptor
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
StorageSpecifierToLLVMTypeConverter()
Helper class to produce LLVM dialect operations extracting or inserting values to a struct.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component. They wrap a pointer to the storage object owned by MLIRContext.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Provides methods to access fields of a sparse tensor with the given encoding.
FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Gets the field index for required field.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
SparseTensorFieldKind toFieldKind(StorageSpecifierKind kind)
unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc)
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy)
Add type casting between arith and index types when needed.
Include the generated interface declarations.
void populateStorageSpecifierToLLVMPatterns(const TypeConverter &converter, RewritePatternSet &patterns)
static Value onLvlSize(OpBuilder &builder, GetStorageSpecifierOp op, SpecifierStructBuilder &spec, Level lvl)
static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op, const SpecifierStructBuilder &spec, Dimension d)
static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op, SpecifierStructBuilder &spec, FieldIndex fidx)
static Value onDimOffset(OpBuilder &builder, GetStorageSpecifierOp op, const SpecifierStructBuilder &spec, Dimension d)
LogicalResult matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static Value onLvlSize(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, Level lvl)
static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, Dimension d)
static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, Dimension d)
static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, FieldIndex fidx)