28 auto enc = tp.getEncoding();
29 const Level lvlRank = enc.getLvlRank();
34 auto sizeType = IntegerType::get(tp.getContext(), 64);
35 auto lvlSizes = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
36 auto memSizes = LLVM::LLVMArrayType::get(ctx, sizeType,
38 result.push_back(lvlSizes);
39 result.push_back(memSizes);
43 auto dimOffset = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
44 auto dimStride = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
46 result.push_back(dimOffset);
47 result.push_back(dimStride);
// Lowers a storage specifier type to a literal LLVM struct whose member types
// are the ones returned by getSpecifierFields(tp), in the order that function
// appends them.
53static Type convertSpecifier(StorageSpecifierType tp) {
54 return LLVM::LLVMStructType::getLiteral(tp.getContext(),
55 getSpecifierFields(tp));
// Top-level member positions inside the lowered specifier struct. These must
// stay in sync with the push_back order in getSpecifierFields: level sizes (0),
// memory sizes (1), dimension offsets (2), dimension strides (3).
// NOTE(review): the lines around the offset/stride push_backs are not fully
// visible in this view; confirm whether members 2 and 3 always exist.
62constexpr uint64_t kLvlSizePosInSpecifier = 0;
63constexpr uint64_t kMemSizePosInSpecifier = 1;
64constexpr uint64_t kDimOffsetPosInSpecifier = 2;
65constexpr uint64_t kDimStridePosInSpecifier = 3;
// Reads the element at the nested position `indices` of the wrapped struct
// (`value`) using llvm.extractvalue.
// NOTE(review): the extracted value is handed to an enclosing call whose start
// is not visible here (possibly a cast such as genCast). Confirm the type that
// is actually returned.
69 Value extractField(OpBuilder &builder, Location loc,
70 ArrayRef<int64_t>
indices)
const {
72 LLVM::ExtractValueOp::create(builder, loc, value,
indices),
// Writes an element at the nested position `indices` of the wrapped struct.
// llvm.insertvalue yields a new struct value, so `value` is rebound to it.
// Later reads through this builder (e.g. extractField) then see the update.
76 void insertField(OpBuilder &builder, Location loc, ArrayRef<int64_t>
indices,
78 value = LLVM::InsertValueOp::create(
// Wraps an already-built specifier struct value.
84 explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) {
// Creates a fresh specifier struct of type `structType`. The out-of-line
// definition starts from a poison struct and either zeroes the memory sizes or
// copies them from a source specifier; the selecting condition is not visible
// in this view.
89 static Value getInitValue(OpBuilder &builder, Location loc, Type structType,
// Field accessors. Getters read, and setters overwrite, one element of the
// specifier struct:
//   lvlSize / setLvlSize     : [kLvlSizePosInSpecifier][lvl]
//   dimOffset / setDimOffset : [kDimOffsetPosInSpecifier][dim]
//   dimStride / setDimStride : [kDimStridePosInSpecifier][dim]
//   memSize / setMemSize     : [kMemSizePosInSpecifier][fidx]
// memSizeArray / setMemSizeArray read and replace the whole memory-size array
// member at [kMemSizePosInSpecifier].
92 Value lvlSize(OpBuilder &builder, Location loc,
Level lvl)
const;
93 void setLvlSize(OpBuilder &builder, Location loc,
Level lvl, Value size);
95 Value dimOffset(OpBuilder &builder, Location loc,
Dimension dim)
const;
96 void setDimOffset(OpBuilder &builder, Location loc,
Dimension dim,
99 Value dimStride(OpBuilder &builder, Location loc,
Dimension dim)
const;
100 void setDimStride(OpBuilder &builder, Location loc,
Dimension dim,
103 Value memSize(OpBuilder &builder, Location loc,
FieldIndex fidx)
const;
104 void setMemSize(OpBuilder &builder, Location loc,
FieldIndex fidx,
107 Value memSizeArray(OpBuilder &builder, Location loc)
const;
108 void setMemSizeArray(OpBuilder &builder, Location loc, Value array);
// Start from a poison struct and fill in fields explicitly. As far as this
// view shows, only the memory-size member is initialized here.
113 Value metaData = LLVM::PoisonOp::create(builder, loc, structType);
114 SpecifierStructBuilder md(metaData);
// The number of memory-size entries comes from the lowered struct type itself:
// its member at kMemSizePosInSpecifier is an LLVM array.
116 auto memSizeArrayType =
117 cast<LLVM::LLVMArrayType>(cast<LLVM::LLVMStructType>(structType)
118 .getBody()[kMemSizePosInSpecifier]);
// Zero-fill every memory-size entry, using a constant of the array's element type.
// NOTE(review): `i`/`e` are `int`, while getNumElements() appears to return an
// unsigned count and setMemSize takes a FieldIndex (unsigned). This is harmless
// for realistic field counts, but an unsigned index would avoid the conversions.
120 Value zero =
constantZero(builder, loc, memSizeArrayType.getElementType());
122 for (
int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
123 md.setMemSize(builder, loc, i, zero);
// Alternatively, copy the whole memory-size array from `source`.
// NOTE(review): the branch that chooses between zero-fill and copy is not
// visible here; presumably it tests whether `source` is null. Confirm.
126 SpecifierStructBuilder sourceMd(source);
127 md.setMemSizeArray(builder, loc, sourceMd.memSizeArray(builder, loc));
// dimOffset: reads element `dim` of the dimension-offset array
// (member kDimOffsetPosInSpecifier).
133Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc,
137 ArrayRef<int64_t>{kDimOffsetPosInSpecifier,
static_cast<int64_t
>(dim)});
// setDimOffset: overwrites element `dim` of the dimension-offset array.
141void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc,
145 ArrayRef<int64_t>{kDimOffsetPosInSpecifier,
static_cast<int64_t
>(dim)},
// lvlSize: reads element `lvl` of the level-size array
// (member kLvlSizePosInSpecifier).
150Value SpecifierStructBuilder::lvlSize(OpBuilder &builder, Location loc,
156 ArrayRef<int64_t>{kLvlSizePosInSpecifier,
static_cast<int64_t
>(lvl)});
// setLvlSize: overwrites element `lvl` of the level-size array with `size`.
160void SpecifierStructBuilder::setLvlSize(OpBuilder &builder, Location loc,
161 Level lvl, Value size) {
166 ArrayRef<int64_t>{kLvlSizePosInSpecifier,
static_cast<int64_t
>(lvl)},
// dimStride: reads element `dim` of the dimension-stride array
// (member kDimStridePosInSpecifier).
171Value SpecifierStructBuilder::dimStride(OpBuilder &builder, Location loc,
175 ArrayRef<int64_t>{kDimStridePosInSpecifier,
static_cast<int64_t
>(dim)});
// setDimStride: overwrites element `dim` of the dimension-stride array.
179void SpecifierStructBuilder::setDimStride(OpBuilder &builder, Location loc,
183 ArrayRef<int64_t>{kDimStridePosInSpecifier,
static_cast<int64_t
>(dim)},
// memSize: reads entry `fidx` of the memory-size array
// (member kMemSizePosInSpecifier).
188Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
192 ArrayRef<int64_t>{kMemSizePosInSpecifier,
static_cast<int64_t
>(fidx)});
// setMemSize: overwrites entry `fidx` of the memory-size array.
196void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
200 ArrayRef<int64_t>{kMemSizePosInSpecifier,
static_cast<int64_t
>(fidx)},
// memSizeArray: returns the whole memory-size array member straight from
// llvm.extractvalue. Unlike the per-element getters, no further call wraps the
// result here.
205Value SpecifierStructBuilder::memSizeArray(OpBuilder &builder,
206 Location loc)
const {
207 return LLVM::ExtractValueOp::create(builder, loc, value,
208 kMemSizePosInSpecifier);
// setMemSizeArray: replaces the whole memory-size array member and rebinds
// `value` to the struct produced by llvm.insertvalue.
212void SpecifierStructBuilder::setMemSizeArray(OpBuilder &builder, Location loc,
214 value = LLVM::InsertValueOp::create(builder, loc, value, array,
215 kMemSizePosInSpecifier);
// Conversion rules: by default any type is kept as-is, while storage specifier
// types are lowered with convertSpecifier. TypeConverter tries the most
// recently added conversion first, so the specifier rule takes precedence
// over the identity fallback.
225 addConversion([](
Type type) {
return type; });
226 addConversion(convertSpecifier);
// Shared conversion pattern for the get/set specifier ops. It does three things:
//   1. wraps the converted specifier in a SpecifierStructBuilder;
//   2. dispatches on the op's specifier kind to the static hooks of `Base`
//      (onLvlSize, onDimOffset, onDimStride, onMemSize);
//   3. replaces the op with the hook's result.
// Presumably `Base` is the derived converter itself (CRTP). Confirm.
233template <
typename Base,
typename SourceOp>
237 using OpConversionPattern<SourceOp>::OpConversionPattern;
241 ConversionPatternRewriter &rewriter)
const override {
242 SpecifierStructBuilder spec(adaptor.getSpecifier());
243 switch (op.getSpecifierKind()) {
244 case StorageSpecifierKind::LvlSize: {
245 Value v = Base::onLvlSize(rewriter, op, spec, (*op.getLevel()));
246 rewriter.replaceOp(op, v);
// NOTE(review): for DimOffset/DimStride the op's *level* operand is forwarded
// as a Dimension. Level and Dimension are both uint64_t, so this compiles, and
// getSpecifierFields sizes these arrays by lvlRank. Confirm that level and
// dimension indices are meant to coincide here.
249 case StorageSpecifierKind::DimOffset: {
250 Value v = Base::onDimOffset(rewriter, op, spec, (*op.getLevel()));
251 rewriter.replaceOp(op, v);
254 case StorageSpecifierKind::DimStride: {
255 Value v = Base::onDimStride(rewriter, op, spec, (*op.getLevel()));
256 rewriter.replaceOp(op, v);
// The memory-size kinds are mapped to a flat field index `idx` for onMemSize.
// The level operand is consulted only under a condition that is not visible here.
// NOTE(review): `lvl` is std::optional<unsigned>, while Level is uint64_t and
// getMemRefFieldIndex takes std::optional<Level>. Confirm the narrowing is
// intended.
259 case StorageSpecifierKind::CrdMemSize:
260 case StorageSpecifierKind::PosMemSize:
261 case StorageSpecifierKind::ValMemSize: {
262 auto enc = op.getSpecifier().getType().getEncoding();
264 std::optional<unsigned> lvl;
266 lvl = (*op.getLevel());
269 Value v = Base::onMemSize(rewriter, op, spec, idx);
270 rewriter.replaceOp(op, v);
// NOTE(review): "specifer" in the message below is a typo for "specifier".
274 llvm_unreachable(
"unrecognized specifer kind");
// Setter hooks for SetStorageSpecifierOp. Each one writes op.getValue() into
// the matching field of `spec` through the SpecifierStructBuilder setters.
// NOTE(review): the Value each hook returns is not visible in this view.
280 SetStorageSpecifierOp> {
281 using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
284 SpecifierStructBuilder &spec,
Level lvl) {
285 spec.setLvlSize(builder, op.getLoc(), lvl, op.getValue());
290 SpecifierStructBuilder &spec,
Dimension d) {
291 spec.setDimOffset(builder, op.getLoc(), d, op.getValue());
296 SpecifierStructBuilder &spec,
Dimension d) {
297 spec.setDimStride(builder, op.getLoc(), d, op.getValue());
302 SpecifierStructBuilder &spec,
FieldIndex fidx) {
303 spec.setMemSize(builder, op.getLoc(), fidx, op.getValue());
// Getter hooks for GetStorageSpecifierOp. Each one reads the matching field
// from `spec` and returns it as the replacement value.
// NOTE(review): onDimOffset/onDimStride take `const SpecifierStructBuilder &`,
// but onLvlSize/onMemSize take a non-const reference even though lvlSize() and
// memSize() are const members. Consider making all four const for consistency.
310 GetStorageSpecifierOp> {
311 using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
314 SpecifierStructBuilder &spec,
Level lvl) {
315 return spec.lvlSize(builder, op.getLoc(), lvl);
319 const SpecifierStructBuilder &spec,
Dimension d) {
320 return spec.dimOffset(builder, op.getLoc(), d);
324 const SpecifierStructBuilder &spec,
Dimension d) {
325 return spec.dimStride(builder, op.getLoc(), d);
329 SpecifierStructBuilder &spec,
FieldIndex fidx) {
330 return spec.memSize(builder, op.getLoc(), fidx);
// Lowers StorageSpecifierInitOp. It converts the result type to its LLVM
// struct form and replaces the op with SpecifierStructBuilder::getInitValue,
// forwarding adaptor.getSource().
// NOTE(review): the source is presumably null when the op has no source
// operand. Confirm.
335 :
public OpConversionPattern<StorageSpecifierInitOp> {
337 using OpConversionPattern::OpConversionPattern;
340 ConversionPatternRewriter &rewriter)
const override {
341 Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
343 op, SpecifierStructBuilder::getInitValue(
344 rewriter, op.getLoc(), llvmType, adaptor.getSource()));
LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SourceOp::Adaptor OpAdaptor
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
StorageSpecifierToLLVMTypeConverter()
Helper class to produce LLVM dialect operations extracting or inserting values to a struct.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Provides methods to access fields of a sparse tensor with the given encoding.
FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Gets the field index for required field.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
unsigned FieldIndex
The type of field indices.
uint64_t Level
The type of level identifiers and level-ranks.
SparseTensorFieldKind toFieldKind(StorageSpecifierKind kind)
unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc)
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy)
Add type casting between arith and index types when needed.
Include the generated interface declarations.
void populateStorageSpecifierToLLVMPatterns(const TypeConverter &converter, RewritePatternSet &patterns)
const FrozenRewritePatternSet & patterns
static Value onLvlSize(OpBuilder &builder, GetStorageSpecifierOp op, SpecifierStructBuilder &spec, Level lvl)
static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op, const SpecifierStructBuilder &spec, Dimension d)
static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op, SpecifierStructBuilder &spec, FieldIndex fidx)
static Value onDimOffset(OpBuilder &builder, GetStorageSpecifierOp op, const SpecifierStructBuilder &spec, Dimension d)
LogicalResult matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static Value onLvlSize(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, Level lvl)
static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, Dimension d)
static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, Dimension d)
static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, FieldIndex fidx)