18 using namespace sparse_tensor;
28 auto enc = tp.getEncoding();
29 const Level lvlRank = enc.getLvlRank();
38 result.push_back(lvlSizes);
39 result.push_back(memSizes);
46 result.push_back(dimOffset);
47 result.push_back(dimStride);
53 static Type convertSpecifier(StorageSpecifierType tp) {
54 return LLVM::LLVMStructType::getLiteral(tp.getContext(),
55 getSpecifierFields(tp));
// Field positions inside the LLVM struct that holds a storage specifier.
// Each constant is the index of one member of that struct. For example,
// kMemSizePosInSpecifier selects the memory-size array: it indexes the
// struct body (getBody()[kMemSizePosInSpecifier]) and is the position used
// when the memory-size array is extracted or inserted as a whole.
// NOTE(review): the order appears to mirror the result.push_back(...)
// sequence that builds the struct's field list (lvlSizes, memSizes,
// dimOffset, dimStride). Confirm the two stay in sync if either changes.
62 constexpr uint64_t kLvlSizePosInSpecifier = 0;
63 constexpr uint64_t kMemSizePosInSpecifier = 1;
64 constexpr uint64_t kDimOffsetPosInSpecifier = 2;
65 constexpr uint64_t kDimStridePosInSpecifier = 3;
72 builder.
create<LLVM::ExtractValueOp>(loc, value, indices),
78 value = builder.
create<LLVM::InsertValueOp>(
113 Value metaData = builder.
create<LLVM::UndefOp>(loc, structType);
114 SpecifierStructBuilder md(metaData);
116 auto memSizeArrayType =
117 cast<LLVM::LLVMArrayType>(cast<LLVM::LLVMStructType>(structType)
118 .getBody()[kMemSizePosInSpecifier]);
122 for (
int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
123 md.setMemSize(builder, loc, i, zero);
126 SpecifierStructBuilder sourceMd(source);
127 md.setMemSizeArray(builder, loc, sourceMd.memSizeArray(builder, loc));
207 return builder.
create<LLVM::ExtractValueOp>(loc, value,
208 kMemSizePosInSpecifier);
214 value = builder.
create<LLVM::InsertValueOp>(loc, value, array,
215 kMemSizePosInSpecifier);
225 addConversion([](
Type type) {
return type; });
226 addConversion(convertSpecifier);
233 template <
typename Base,
typename SourceOp>
242 SpecifierStructBuilder spec(adaptor.getSpecifier());
243 switch (op.getSpecifierKind()) {
244 case StorageSpecifierKind::LvlSize: {
245 Value v = Base::onLvlSize(rewriter, op, spec, (*op.getLevel()));
249 case StorageSpecifierKind::DimOffset: {
250 Value v = Base::onDimOffset(rewriter, op, spec, (*op.getLevel()));
254 case StorageSpecifierKind::DimStride: {
255 Value v = Base::onDimStride(rewriter, op, spec, (*op.getLevel()));
259 case StorageSpecifierKind::CrdMemSize:
260 case StorageSpecifierKind::PosMemSize:
261 case StorageSpecifierKind::ValMemSize: {
262 auto enc = op.getSpecifier().getType().getEncoding();
264 std::optional<unsigned> lvl;
266 lvl = (*op.getLevel());
269 Value v = Base::onMemSize(rewriter, op, spec, idx);
274 llvm_unreachable(
"unrecognized specifer kind");
280 SetStorageSpecifierOp> {
281 using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
284 SpecifierStructBuilder &spec,
Level lvl) {
285 spec.setLvlSize(builder, op.getLoc(), lvl, op.getValue());
290 SpecifierStructBuilder &spec,
Dimension d) {
291 spec.setDimOffset(builder, op.getLoc(), d, op.getValue());
296 SpecifierStructBuilder &spec,
Dimension d) {
297 spec.setDimStride(builder, op.getLoc(), d, op.getValue());
302 SpecifierStructBuilder &spec,
FieldIndex fidx) {
303 spec.setMemSize(builder, op.getLoc(), fidx, op.getValue());
310 GetStorageSpecifierOp> {
311 using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
314 SpecifierStructBuilder &spec,
Level lvl) {
315 return spec.lvlSize(builder, op.getLoc(), lvl);
319 const SpecifierStructBuilder &spec,
Dimension d) {
320 return spec.dimOffset(builder, op.getLoc(), d);
324 const SpecifierStructBuilder &spec,
Dimension d) {
325 return spec.dimStride(builder, op.getLoc(), d);
329 SpecifierStructBuilder &spec,
FieldIndex fidx) {
330 return spec.memSize(builder, op.getLoc(), fidx);
341 Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
343 op, SpecifierStructBuilder::getInitValue(
344 rewriter, op.getLoc(), llvmType, adaptor.getSource()));
LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
IntegerType getIntegerType(unsigned width)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
StorageSpecifierToLLVMTypeConverter()
Helper class to produce LLVM dialect operations extracting or inserting values into a struct.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Provides methods to access fields of a sparse tensor with the given encoding.
FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind, std::optional< Level > lvl) const
Gets the field index for required field.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
unsigned FieldIndex
The type of field indices.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
uint64_t Level
The type of level identifiers and level-ranks.
SparseTensorFieldKind toFieldKind(StorageSpecifierKind kind)
unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc)
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy)
Add type casting between arith and index types when needed.
Include the generated interface declarations.
void populateStorageSpecifierToLLVMPatterns(const TypeConverter &converter, RewritePatternSet &patterns)
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
const TypeConverter & converter
static Value onLvlSize(OpBuilder &builder, GetStorageSpecifierOp op, SpecifierStructBuilder &spec, Level lvl)
static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op, const SpecifierStructBuilder &spec, Dimension d)
static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op, SpecifierStructBuilder &spec, FieldIndex fidx)
static Value onDimOffset(OpBuilder &builder, GetStorageSpecifierOp op, const SpecifierStructBuilder &spec, Dimension d)
LogicalResult matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static Value onLvlSize(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, Level lvl)
static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, Dimension d)
static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, Dimension d)
static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, FieldIndex fidx)