MLIR 22.0.0git
SparseTensorDialect.cpp File Reference
#include <utility>
#include "Detail/DimLvlMapParser.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"

Go to the source code of this file.

Classes

struct  RemoveUnusedLvlCrds

Namespaces

namespace  mlir
 Include the generated interface declarations.
namespace  mlir::sparse_tensor

Macros

#define GET_ATTRDEF_CLASSES
#define GET_TYPEDEF_CLASSES
#define GET_ATTRDEF_LIST
#define GET_TYPEDEF_LIST
#define GET_OP_LIST
#define GET_OP_CLASSES

Functions

static mlir::ParseResult parseLevelRange (AsmParser &parser, Level &lvlLo, Level &lvlHi)
 Parses a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static void printLevelRange (AsmPrinter &p, Level lo, Level hi)
 Prints a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
llvm::hash_code mlir::sparse_tensor::hash_value (LevelType lt)
static constexpr bool acceptBitWidth (unsigned bitWidth)
static SmallVector< Size > getSparseFieldShape (const SparseTensorEncodingAttr enc, std::optional< ArrayRef< int64_t > > dimShape)
static ParseResult parseOptionalStaticSlice (int64_t &result, AsmParser &parser)
static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier (SparseTensorEncodingAttr enc)
 We normalize the sparse tensor encoding attribute by always using ordered/unique level types (LTs), so that "compressed_nu_no" and "compressed_nu" (as well as other variants) lead to the same storage specifier type, and by stripping irrelevant fields that do not alter the sparse tensor memory layout.
static LogicalResult lvlIsInBounds (Level lvl, Value tensor)
static LogicalResult isMatchingWidth (Value mem, unsigned width)
static LogicalResult verifySparsifierGetterSetter (StorageSpecifierKind mdKind, std::optional< Level > lvl, TypedValue< StorageSpecifierType > md, Operation *op)
static Type getFieldElemType (SparseTensorType stt, SparseTensorFieldKind kind)
static LogicalResult verifyPackUnPack (Operation *op, bool requiresStaticShape, SparseTensorType stt, RankedTensorType valTp, TypeRange lvlTps)
template<typename ToBufferOp>
static LogicalResult inferSparseBufferType (ValueRange ops, DictionaryAttr attr, OpaqueProperties prop, RegionRange region, SmallVectorImpl< mlir::Type > &ret)
template<typename SpecifierOp>
static SetStorageSpecifierOp getSpecifierSetDef (SpecifierOp op)
template<class T>
static LogicalResult verifyNumBlockArgs (T *op, Region &region, const char *regionName, TypeRange inputTypes, Type outputType)
static ParseResult parseLevelRange (OpAsmParser &parser, IntegerAttr &lvlLoAttr, IntegerAttr &lvlHiAttr)
 Parses a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static void printLevelRange (OpAsmPrinter &p, Operation *, IntegerAttr lvlLo, IntegerAttr lvlHi)
 Prints a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.
static ParseResult parseOptionalDefinedList (OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, SmallVectorImpl< OpAsmParser::Argument > &definedArgs, unsigned maxCnt=std::numeric_limits< unsigned >::max(), OpAsmParser::Delimiter delimiter=OpAsmParser::Delimiter::Paren)
 Parses an optionally-defined list in the form "(%val0, _, %val1, ...)", where _ marks that the corresponding value is not defined (e.g., to represent an undefined coordinate in the sparse iteration space).
static void printOptionalDefinedList (OpAsmPrinter &p, unsigned size, Block::BlockArgListType blocksArgs, I64BitSet definedSet)
static ParseResult parseUsedCoordList (OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &coords)
static ParseResult parseSparseIterateLoop (OpAsmParser &parser, OperationState &state, SmallVectorImpl< OpAsmParser::Argument > &iterators, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static ParseResult parseSparseCoIterateLoop (OpAsmParser &parser, OperationState &state, SmallVectorImpl< Value > &spacesVals, SmallVectorImpl< OpAsmParser::Argument > &blockArgs)
static void printInitializationList (OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
 Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2, <...>) where 'inner' values are assumed to be region arguments and 'outer' values are regular SSA values.
template<typename SparseLoopOp>
static LogicalResult verifySparseLoopOp (SparseLoopOp op)

Variables

static constexpr Level kInvalidLevel = -1u
static constexpr Level kInvalidFieldIndex = -1u
static constexpr FieldIndex kDataFieldStartingIdx = 0

Macro Definition Documentation

◆ GET_ATTRDEF_CLASSES

#define GET_ATTRDEF_CLASSES

Definition at line 29 of file SparseTensorDialect.cpp.

◆ GET_ATTRDEF_LIST

#define GET_ATTRDEF_LIST

◆ GET_OP_CLASSES

#define GET_OP_CLASSES

Definition at line 2797 of file SparseTensorDialect.cpp.

◆ GET_OP_LIST

#define GET_OP_LIST

◆ GET_TYPEDEF_CLASSES

#define GET_TYPEDEF_CLASSES

Definition at line 41 of file SparseTensorDialect.cpp.

◆ GET_TYPEDEF_LIST

#define GET_TYPEDEF_LIST

Function Documentation

◆ acceptBitWidth()

constexpr bool acceptBitWidth ( unsigned bitWidth)
static constexpr

Definition at line 59 of file SparseTensorDialect.cpp.

◆ getFieldElemType()

◆ getNormalizedEncodingForSpecifier()

SparseTensorEncodingAttr getNormalizedEncodingForSpecifier ( SparseTensorEncodingAttr enc)
static

We normalize the sparse tensor encoding attribute by always using ordered/unique level types (LTs), so that "compressed_nu_no" and "compressed_nu" (as well as other variants) lead to the same storage specifier type, and by stripping irrelevant fields that do not alter the sparse tensor memory layout.

Definition at line 1204 of file SparseTensorDialect.cpp.

References mlir::sparse_tensor::LevelType::stripStorageIrrelevantProperties().

◆ getSparseFieldShape()

SmallVector< Size > getSparseFieldShape ( const SparseTensorEncodingAttr enc,
std::optional< ArrayRef< int64_t > > dimShape )
static

◆ getSpecifierSetDef()

template<typename SpecifierOp>
SetStorageSpecifierOp getSpecifierSetDef ( SpecifierOp op)
static

Definition at line 1715 of file SparseTensorDialect.cpp.

◆ inferSparseBufferType()

◆ isMatchingWidth()

LogicalResult isMatchingWidth ( Value mem,
unsigned width )
static

◆ lvlIsInBounds()

LogicalResult lvlIsInBounds ( Level lvl,
Value tensor )
static

◆ parseLevelRange() [1/2]

ParseResult parseLevelRange ( mlir::AsmParser & parser,
mlir::sparse_tensor::Level & lvlLo,
mlir::sparse_tensor::Level & lvlHi )
static

Parses a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.

Definition at line 2094 of file SparseTensorDialect.cpp.

References mlir::AsmParser::emitError(), mlir::AsmParser::getNameLoc(), mlir::AsmParser::parseInteger(), mlir::AsmParser::parseOptionalKeyword(), and success().

Referenced by parseLevelRange().

◆ parseLevelRange() [2/2]

ParseResult parseLevelRange ( OpAsmParser & parser,
IntegerAttr & lvlLoAttr,
IntegerAttr & lvlHiAttr )
static

Parses a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.

Definition at line 2115 of file SparseTensorDialect.cpp.

References mlir::AsmParser::getBuilder(), mlir::Builder::getIndexType(), parseLevelRange(), and success().

◆ parseOptionalDefinedList()

ParseResult parseOptionalDefinedList ( OpAsmParser & parser,
OperationState & state,
I64BitSet & definedSet,
SmallVectorImpl< OpAsmParser::Argument > & definedArgs,
unsigned maxCnt = std::numeric_limits<unsigned>::max(),
OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren )
static

Parses an optionally-defined list in the form "(%val0, _, %val1, ...)", where _ marks that the corresponding value is not defined (e.g., to represent an undefined coordinate in the sparse iteration space).

Definition at line 2149 of file SparseTensorDialect.cpp.

References mlir::sparse_tensor::I64BitSet::count(), mlir::AsmParser::emitError(), mlir::AsmParser::getNameLoc(), mlir::AsmParser::Paren, mlir::OpAsmParser::parseArgument(), mlir::AsmParser::parseCommaSeparatedList(), mlir::AsmParser::parseOptionalKeyword(), mlir::sparse_tensor::I64BitSet::set(), and success().

Referenced by parseUsedCoordList().

◆ parseOptionalStaticSlice()

ParseResult parseOptionalStaticSlice ( int64_t & result,
AsmParser & parser )
static

◆ parseSparseCoIterateLoop()

◆ parseSparseIterateLoop()

◆ parseUsedCoordList()

◆ printInitializationList()

void printInitializationList ( OpAsmPrinter & p,
Block::BlockArgListType blocksArgs,
ValueRange initializers,
StringRef prefix = "" )
static

Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2, <...>) where 'inner' values are assumed to be region arguments and 'outer' values are regular SSA values.

Definition at line 2500 of file SparseTensorDialect.cpp.

◆ printLevelRange() [1/2]

Prints a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.

Definition at line 2128 of file SparseTensorDialect.cpp.

Referenced by printLevelRange().

◆ printLevelRange() [2/2]

void printLevelRange ( OpAsmPrinter & p,
Operation * ,
IntegerAttr lvlLo,
IntegerAttr lvlHi )
static

Prints a level range in the form "$lo `to` $hi" or simply "$lo" if $hi - $lo = 1.

Definition at line 2138 of file SparseTensorDialect.cpp.

References printLevelRange().

◆ printOptionalDefinedList()

void printOptionalDefinedList ( OpAsmPrinter & p,
unsigned size,
Block::BlockArgListType blocksArgs,
I64BitSet definedSet )
static

Definition at line 2179 of file SparseTensorDialect.cpp.

References mlir::sparse_tensor::I64BitSet::empty().

◆ verifyNumBlockArgs()

template<class T>
LogicalResult verifyNumBlockArgs ( T * op,
Region & region,
const char * regionName,
TypeRange inputTypes,
Type outputType )
static

◆ verifyPackUnPack()

◆ verifySparseLoopOp()

template<typename SparseLoopOp>
LogicalResult verifySparseLoopOp ( SparseLoopOp op)
static

Definition at line 2517 of file SparseTensorDialect.cpp.

References success().

◆ verifySparsifierGetterSetter()

LogicalResult verifySparsifierGetterSetter ( StorageSpecifierKind mdKind,
std::optional< Level > lvl,
TypedValue< StorageSpecifierType > md,
Operation * op )
static

Definition at line 1249 of file SparseTensorDialect.cpp.

References mlir::Operation::emitError(), and success().

Variable Documentation

◆ kDataFieldStartingIdx

FieldIndex kDataFieldStartingIdx = 0
static constexpr

◆ kInvalidFieldIndex

Level kInvalidFieldIndex = -1u
static constexpr

◆ kInvalidLevel

Level kInvalidLevel = -1u
static constexpr