9 #ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
10 #define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
15 #include "llvm/ADT/EnumeratedArray.h"
16 #include "llvm/ADT/STLForwardCompat.h"
17 #include "llvm/ADT/SmallBitVector.h"
18 #include "llvm/ADT/StringMap.h"
21 namespace sparse_tensor {
35 const auto vk_ = llvm::to_underlying(vk);
36 return 0 <= vk_ && vk_ <= 2;
45 const auto vk_ =
static_cast<int_fast8_t
>(llvm::to_underlying(vk));
46 return static_cast<char>(100 + vk_ * (26 - vk_ * 11));
55 using VarKindArray = llvm::EnumeratedArray<T, VarKind, VarKind::Level>;
68 using Storage = unsigned;
75 static constexpr
Num kMaxNum =
83 [[nodiscard]]
static constexpr
bool isWF_Num(
Num n) {
return n <= kMaxNum; }
95 : data((static_cast<Storage>(n) << 2) |
96 static_cast<Storage>(
llvm::to_underlying(vk))) {
97 assert(
isWF(vk) &&
"unknown VarKind");
98 assert(
isWF_Num(n) &&
"Var::Num is too large");
105 static_assert(IsZeroCostAbstraction<Impl>);
/// Tests whether this variable is an instance of the variable class `U`.
/// NOTE(review): the definition is only partially visible in this view; it
/// appears to dispatch on `U` being `SymVar`, `DimVar`, or `LvlVar` -- confirm.
127 template <
typename U>
128 constexpr
bool isa()
const;
/// Converts this variable to the variable class `U`.
/// NOTE(review): the definition is not visible in this view; presumably the
/// caller must ensure `isa<U>()` holds -- confirm against the implementation.
129 template <
typename U>
130 constexpr U
cast()
const;
/// Checked conversion to the variable class `U`.
/// Judging by the visible fragment of the definition, this yields `U(impl)`
/// wrapped in an optional when `isa<U>()` holds, and `std::nullopt` otherwise.
131 template <
typename U>
132 constexpr std::optional<U>
dyn_cast()
const;
/// Returns a textual rendering of this variable as a `std::string`.
/// NOTE(review): the implementation is not visible here; presumably it
/// produces the same text that `print` writes -- confirm.
134 std::string
str()
const;
/// Writes a textual rendering of this variable to the stream `os`.
/// NOTE(review): the exact output format is not visible in this view.
135 void print(llvm::raw_ostream &os)
const;
139 static_assert(IsZeroCostAbstraction<Var>);
151 static_assert(IsZeroCostAbstraction<SymVar>);
163 static_assert(IsZeroCostAbstraction<DimVar>);
175 static_assert(IsZeroCostAbstraction<LvlVar>);
177 template <
typename U>
179 if constexpr (std::is_same_v<U, SymVar>)
181 if constexpr (std::is_same_v<U, DimVar>)
183 if constexpr (std::is_same_v<U, LvlVar>)
187 template <
typename U>
194 template <
typename U>
197 return isa<U>() ? std::make_optional(U(
impl)) : std::nullopt;
209 static constexpr
unsigned to_index(
VarKind vk) {
210 assert(
isWF(vk) &&
"unknown VarKind");
211 return static_cast<unsigned>(llvm::to_underlying(vk));
215 constexpr
Ranks(
unsigned symRank,
unsigned dimRank,
unsigned lvlRank)
238 static_assert(IsZeroCostAbstraction<Ranks>);
/// Newtype for unique identifiers of `VarInfo` records, to ensure they
/// aren't confused with `Var::Num`. The enum deliberately declares no
/// enumerators: it only serves as a distinct type over `unsigned`.
279 enum class ID : unsigned {};
285 std::optional<Var::Num> num;
290 std::optional<Var::Num> n = {})
291 : name(name), loc(loc), id(
id), num(n), kind(vk) {
292 assert(!name.empty() &&
"null StringRef");
293 assert(loc.isValid() &&
"null SMLoc");
294 assert(
isWF(vk) &&
"unknown VarKind");
295 assert((!n ||
Var::isWF_Num(*n)) &&
"Var::Num is too large");
298 constexpr StringRef
getName()
const {
return name; }
299 constexpr llvm::SMLoc
getLoc()
const {
return loc; }
305 constexpr std::optional<Var::Num>
getNum()
const {
return num; }
306 constexpr
bool hasNum()
const {
return num.has_value(); }
310 return Var(kind, *num);
324 llvm::StringMap<VarInfo::ID> ids;
340 return vars[llvm::to_underlying(
id)];
343 return oid ? &
access(*oid) :
nullptr;
348 return const_cast<VarInfo &
>(std::as_const(*this).access(
id));
350 VarInfo *
access(std::optional<VarInfo::ID> oid) {
351 return const_cast<VarInfo *
>(std::as_const(*this).access(oid));
356 std::optional<VarInfo::ID>
lookup(StringRef name)
const;
364 std::optional<std::pair<VarInfo::ID, bool>>
365 create(StringRef name, llvm::SMLoc loc,
VarKind vk,
bool verifyUsage =
false);
373 std::optional<std::pair<VarInfo::ID, bool>>
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
A dimensional identifier appearing in an affine expression.
A symbolic identifier appearing in an affine expression.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
This base class exposes generic asm printer hooks, usable across the various derived printers.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
static constexpr bool classof(Var const *var)
static constexpr VarKind Kind
constexpr DimVar(Num dim)
DimVar(AffineDimExpr dimExpr)
constexpr LvlVar(Num lvl)
static constexpr VarKind Kind
static constexpr bool classof(Var const *var)
LvlVar(AffineDimExpr lvlExpr)
constexpr unsigned getRank(VarKind vk) const
constexpr unsigned getLvlRank() const
bool operator==(Ranks const &other) const
Ranks(VarKindArray< unsigned > const &ranks)
bool operator!=(Ranks const &other) const
constexpr unsigned getDimRank() const
constexpr unsigned getSymRank() const
constexpr Ranks(unsigned symRank, unsigned dimRank, unsigned lvlRank)
constexpr bool isValid(Var var) const
constexpr SymVar(Num sym)
SymVar(AffineSymbolExpr symExpr)
static constexpr VarKind Kind
static constexpr bool classof(Var const *var)
Var bindUnusedVar(VarKind vk)
Creates a new variable of the given kind and immediately binds it.
VarInfo const & access(VarInfo::ID id) const
Gets the underlying storage for the VarInfo identified by the VarInfo::ID.
Ranks getRanks() const
Returns the current ranks of bound variables.
InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const
std::optional< std::pair< VarInfo::ID, bool > > lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc, VarKind vk)
Looks up or creates a variable according to the given Policy.
Var getVar(VarInfo::ID id) const
Gets the Var identified by the VarInfo::ID, raising an assertion failure if the variable is not bound...
std::optional< VarInfo::ID > lookup(StringRef name) const
Looks up the variable with the given name.
VarInfo const * access(std::optional< VarInfo::ID > oid) const
std::optional< std::pair< VarInfo::ID, bool > > create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage=false)
Creates a new currently-unbound variable.
Var bindVar(VarInfo::ID id)
Binds the given variable to the next free Var::Num for its VarKind.
A record of metadata for/about a variable, used by VarEnv.
constexpr VarKind getKind() const
constexpr ID getID() const
constexpr llvm::SMLoc getLoc() const
constexpr Var getVar() const
constexpr StringRef getName() const
ID
Newtype for unique identifiers of VarInfo records, to ensure they aren't confused with Var::Num.
constexpr bool hasNum() const
constexpr std::optional< Var::Num > getNum() const
Location getLocation(AsmParser &parser) const
constexpr VarInfo(ID id, StringRef name, llvm::SMLoc loc, VarKind vk, std::optional< Var::Num > n={})
Efficient representation of a set of Var.
bool contains(Var var) const
For the contains method: if variables occurring in the method parameter are OOB for the VarSet,...
unsigned getRank(VarKind vk) const
unsigned getDimRank() const
unsigned getLvlRank() const
unsigned getSymRank() const
VarSet(Ranks const &ranks)
void add(Var var)
For the add methods: OOB parameters cause undefined behavior.
The underlying implementation of Var.
constexpr Num getNum() const
constexpr bool operator!=(Impl other) const
constexpr VarKind getKind() const
constexpr Impl(VarKind vk, Num n)
constexpr bool operator==(Impl other) const
A concrete variable, to be used in our variant of AffineExpr.
constexpr Num getNum() const
constexpr Var(Impl impl)
Protected ctor for the RTTI methods to use.
constexpr std::optional< U > dyn_cast() const
Var(VarKind vk, AffineDimExpr var)
constexpr VarKind getKind() const
void print(llvm::raw_ostream &os) const
constexpr bool operator!=(Var other) const
constexpr Var(VarKind vk, Num n)
static constexpr bool isWF_Num(Num n)
Checks whether the number would be accepted by Var(VarKind,Var::Num).
constexpr bool isa() const
Var(AffineSymbolExpr sym)
constexpr bool operator==(Var other) const
unsigned Num
Typedef for the type of variable numbers.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr bool isWF(VarKind vk)
VarKind
The three kinds of variables that Var can be.
llvm::EnumeratedArray< T, VarKind, VarKind::Level > VarKindArray
The type of arrays indexed by VarKind.
constexpr char toChar(VarKind vk)
Gets the ASCII character used as the prefix when printing Var.
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
uint64_t Level
The type of level identifiers and level-ranks.
Include the generated interface declarations.