MLIR  19.0.0git
Var.h
Go to the documentation of this file.
1 //===- Var.h ----------------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
10 #define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
11 
12 #include "TemplateExtras.h"
13 
15 #include "llvm/ADT/EnumeratedArray.h"
16 #include "llvm/ADT/STLForwardCompat.h"
17 #include "llvm/ADT/SmallBitVector.h"
18 #include "llvm/ADT/StringMap.h"
19 
20 namespace mlir {
21 namespace sparse_tensor {
22 namespace ir_detail {
23 
24 //===----------------------------------------------------------------------===//
25 /// The three kinds of variables that `Var` can be.
26 ///
27 /// NOTE: The numerical values used to represent this enum should be
28 /// treated as an implementation detail, not as part of the API. In the
29 /// API below we use the canonical ordering `{Symbol,Dimension,Level}` even
30 /// though that does not agree with the numerical ordering of the numerical
31 /// representation.
32 enum class VarKind { Symbol = 1, Dimension = 0, Level = 2 };
33 
34 [[nodiscard]] constexpr bool isWF(VarKind vk) {
35  const auto vk_ = llvm::to_underlying(vk);
36  return 0 <= vk_ && vk_ <= 2;
37 }
38 
39 /// Gets the ASCII character used as the prefix when printing `Var`.
40 constexpr char toChar(VarKind vk) {
41  // If `isWF(vk)` then this computation's intermediate results are always
42  // in the range [-44..126] (where that lower bound is under worst-case
43  // rearranging of the expression); and `int_fast8_t` is the fastest type
44  // which can support that range without over-/underflow.
45  const auto vk_ = static_cast<int_fast8_t>(llvm::to_underlying(vk));
46  return static_cast<char>(100 + vk_ * (26 - vk_ * 11));
47 }
48 static_assert(toChar(VarKind::Symbol) == 's' &&
49  toChar(VarKind::Dimension) == 'd' &&
50  toChar(VarKind::Level) == 'l');
51 
52 //===----------------------------------------------------------------------===//
53 /// The type of arrays indexed by `VarKind`.
54 template <typename T>
55 using VarKindArray = llvm::EnumeratedArray<T, VarKind, VarKind::Level>;
56 
57 //===----------------------------------------------------------------------===//
58 /// A concrete variable, to be used in our variant of `AffineExpr`.
59 /// Client-facing class for `VarKind` + `Var::Num` pairs, with RTTI
60 /// support for subclasses with a fixed `VarKind`.
61 class Var {
62 public:
63  /// Typedef for the type of variable numbers.
64  using Num = unsigned;
65 
66 private:
67  /// Typedef for the underlying storage of `Var::Impl`.
68  using Storage = unsigned;
69 
70  /// The largest `Var::Num` supported by `Var`/`Var::Impl`/`Var::Storage`.
71  /// Two low-order bits are reserved for storing the `VarKind`,
72  /// and one high-order bit is reserved for future use (e.g., to support
73  /// `DenseMapInfo<Var>` while maintaining the usual numeric values for
74  /// "empty" and "tombstone").
75  static constexpr Num kMaxNum =
76  static_cast<Num>(std::numeric_limits<Storage>::max() >> 3);
77 
78 public:
79  /// Checks whether the number would be accepted by `Var(VarKind,Var::Num)`.
80  //
81  // This must be public for `VarInfo` to use it (whereas we don't want
82  // to expose the `impl` field via friendship).
83  [[nodiscard]] static constexpr bool isWF_Num(Num n) { return n <= kMaxNum; }
84 
85 protected:
86  /// The underlying implementation of `Var`. Note that this must be kept
87  /// distinct from `Var` itself, since we want to ensure that the RTTI
88  /// methods will select the `U(Var::Impl)` ctor rather than selecting
89  /// the `U(Var::Num)` ctor.
90  class Impl final {
91  Storage data;
92 
93  public:
94  constexpr Impl(VarKind vk, Num n)
95  : data((static_cast<Storage>(n) << 2) |
96  static_cast<Storage>(llvm::to_underlying(vk))) {
97  assert(isWF(vk) && "unknown VarKind");
98  assert(isWF_Num(n) && "Var::Num is too large");
99  }
100  constexpr bool operator==(Impl other) const { return data == other.data; }
101  constexpr bool operator!=(Impl other) const { return !(*this == other); }
102  constexpr VarKind getKind() const { return static_cast<VarKind>(data & 3); }
103  constexpr Num getNum() const { return static_cast<Num>(data >> 2); }
104  };
105  static_assert(IsZeroCostAbstraction<Impl>);
106 
107 private:
108  Impl impl;
109 
110 protected:
111  /// Protected ctor for the RTTI methods to use.
112  constexpr explicit Var(Impl impl) : impl(impl) {}
113 
114 public:
115  constexpr Var(VarKind vk, Num n) : impl(Impl(vk, n)) {}
116  Var(AffineSymbolExpr sym) : Var(VarKind::Symbol, sym.getPosition()) {}
117  Var(VarKind vk, AffineDimExpr var) : Var(vk, var.getPosition()) {
118  assert(vk != VarKind::Symbol);
119  }
120 
121  constexpr bool operator==(Var other) const { return impl == other.impl; }
122  constexpr bool operator!=(Var other) const { return !(*this == other); }
123 
124  constexpr VarKind getKind() const { return impl.getKind(); }
125  constexpr Num getNum() const { return impl.getNum(); }
126 
127  template <typename U>
128  constexpr bool isa() const;
129  template <typename U>
130  constexpr U cast() const;
131  template <typename U>
132  constexpr std::optional<U> dyn_cast() const;
133 
134  std::string str() const;
135  void print(llvm::raw_ostream &os) const;
136  void print(AsmPrinter &printer) const;
137  void dump() const;
138 };
139 static_assert(IsZeroCostAbstraction<Var>);
140 
141 class SymVar final : public Var {
142  using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
143 public:
144  static constexpr VarKind Kind = VarKind::Symbol;
145  static constexpr bool classof(Var const *var) {
146  return var->getKind() == Kind;
147  }
148  constexpr SymVar(Num sym) : Var(Kind, sym) {}
149  SymVar(AffineSymbolExpr symExpr) : Var(symExpr) {}
150 };
151 static_assert(IsZeroCostAbstraction<SymVar>);
152 
153 class DimVar final : public Var {
154  using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
155 public:
156  static constexpr VarKind Kind = VarKind::Dimension;
157  static constexpr bool classof(Var const *var) {
158  return var->getKind() == Kind;
159  }
160  constexpr DimVar(Num dim) : Var(Kind, dim) {}
161  DimVar(AffineDimExpr dimExpr) : Var(Kind, dimExpr) {}
162 };
163 static_assert(IsZeroCostAbstraction<DimVar>);
164 
165 class LvlVar final : public Var {
166  using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
167 public:
168  static constexpr VarKind Kind = VarKind::Level;
169  static constexpr bool classof(Var const *var) {
170  return var->getKind() == Kind;
171  }
172  constexpr LvlVar(Num lvl) : Var(Kind, lvl) {}
173  LvlVar(AffineDimExpr lvlExpr) : Var(Kind, lvlExpr) {}
174 };
175 static_assert(IsZeroCostAbstraction<LvlVar>);
176 
177 template <typename U>
178 constexpr bool Var::isa() const {
179  if constexpr (std::is_same_v<U, SymVar>)
180  return getKind() == VarKind::Symbol;
181  if constexpr (std::is_same_v<U, DimVar>)
182  return getKind() == VarKind::Dimension;
183  if constexpr (std::is_same_v<U, LvlVar>)
184  return getKind() == VarKind::Level;
185 }
186 
187 template <typename U>
188 constexpr U Var::cast() const {
189  assert(isa<U>());
190  // NOTE: This should select the `U(Var::Impl)` ctor, *not* `U(Var::Num)`
191  return U(impl);
192 }
193 
194 template <typename U>
195 constexpr std::optional<U> Var::dyn_cast() const {
196  // NOTE: This should select the `U(Var::Impl)` ctor, *not* `U(Var::Num)`
197  return isa<U>() ? std::make_optional(U(impl)) : std::nullopt;
198 }
199 
200 //===----------------------------------------------------------------------===//
201 // Forward-decl so that we can declare methods of `Ranks` and `VarSet`.
202 class DimLvlExpr;
203 
204 //===----------------------------------------------------------------------===//
205 class Ranks final {
206  // Not using `VarKindArray` since `EnumeratedArray` doesn't support constexpr.
207  unsigned impl[3];
208 
209  static constexpr unsigned to_index(VarKind vk) {
210  assert(isWF(vk) && "unknown VarKind");
211  return static_cast<unsigned>(llvm::to_underlying(vk));
212  }
213 
214 public:
215  constexpr Ranks(unsigned symRank, unsigned dimRank, unsigned lvlRank)
216  : impl() {
217  impl[to_index(VarKind::Symbol)] = symRank;
218  impl[to_index(VarKind::Dimension)] = dimRank;
219  impl[to_index(VarKind::Level)] = lvlRank;
220  }
222  : Ranks(ranks[VarKind::Symbol], ranks[VarKind::Dimension],
223  ranks[VarKind::Level]) {}
224 
225  bool operator==(Ranks const &other) const;
226  bool operator!=(Ranks const &other) const { return !(*this == other); }
227 
228  constexpr unsigned getRank(VarKind vk) const { return impl[to_index(vk)]; }
229  constexpr unsigned getSymRank() const { return getRank(VarKind::Symbol); }
230  constexpr unsigned getDimRank() const { return getRank(VarKind::Dimension); }
231  constexpr unsigned getLvlRank() const { return getRank(VarKind::Level); }
232 
233  [[nodiscard]] constexpr bool isValid(Var var) const {
234  return var.getNum() < getRank(var.getKind());
235  }
236  [[nodiscard]] bool isValid(DimLvlExpr expr) const;
237 };
238 static_assert(IsZeroCostAbstraction<Ranks>);
239 
240 //===----------------------------------------------------------------------===//
241 /// Efficient representation of a set of `Var`.
242 class VarSet final {
244 
245 public:
246  explicit VarSet(Ranks const &ranks);
247 
248  unsigned getRank(VarKind vk) const { return impl[vk].size(); }
249  unsigned getSymRank() const { return getRank(VarKind::Symbol); }
250  unsigned getDimRank() const { return getRank(VarKind::Dimension); }
251  unsigned getLvlRank() const { return getRank(VarKind::Level); }
252  Ranks getRanks() const {
253  return Ranks(getSymRank(), getDimRank(), getLvlRank());
254  }
255  /// For the `contains` method: if variables occurring in
256  /// the method parameter are OOB for the `VarSet`, then these methods will
257  /// always return false.
258  bool contains(Var var) const;
259 
260  /// For the `add` methods: OOB parameters cause undefined behavior.
261  /// Currently the `add` methods will raise an assertion error.
262  void add(Var var);
263  void add(VarSet const &vars);
264  void add(DimLvlExpr expr);
265 };
266 
267 //===----------------------------------------------------------------------===//
268 /// A record of metadata for/about a variable, used by `VarEnv`.
269 /// The principal goal of this record is to enable `VarEnv` to be used for
270 /// incremental parsing; in particular, `VarInfo` allows the `Var::Num` to
271 /// remain unknown, since each record is instead identified by `VarInfo::ID`.
272 /// Therefore the `VarEnv` can freely allocate `VarInfo::ID` in whatever
273 /// order it likes, irrespective of the binding order (`Var::Num`) of the
274 /// associated variable.
275 class VarInfo final {
276 public:
277  /// Newtype for unique identifiers of `VarInfo` records, to ensure
278  /// they aren't confused with `Var::Num`.
279  enum class ID : unsigned {};
280 
281 private:
282  StringRef name; // The bare-id used in the MLIR source.
283  llvm::SMLoc loc; // The location of the first occurence.
284  ID id; // The unique `VarInfo`-identifier.
285  std::optional<Var::Num> num; // The unique `Var`-identifier (if resolved).
286  VarKind kind; // The kind of variable.
287 
288 public:
289  constexpr VarInfo(ID id, StringRef name, llvm::SMLoc loc, VarKind vk,
290  std::optional<Var::Num> n = {})
291  : name(name), loc(loc), id(id), num(n), kind(vk) {
292  assert(!name.empty() && "null StringRef");
293  assert(loc.isValid() && "null SMLoc");
294  assert(isWF(vk) && "unknown VarKind");
295  assert((!n || Var::isWF_Num(*n)) && "Var::Num is too large");
296  }
297 
298  constexpr StringRef getName() const { return name; }
299  constexpr llvm::SMLoc getLoc() const { return loc; }
300  Location getLocation(AsmParser &parser) const {
301  return parser.getEncodedSourceLoc(loc);
302  }
303  constexpr ID getID() const { return id; }
304  constexpr VarKind getKind() const { return kind; }
305  constexpr std::optional<Var::Num> getNum() const { return num; }
306  constexpr bool hasNum() const { return num.has_value(); }
307  void setNum(Var::Num n);
308  constexpr Var getVar() const {
309  assert(hasNum());
310  return Var(kind, *num);
311  }
312 };
313 
314 //===----------------------------------------------------------------------===//
/// Creation policy accepted by `VarEnv::lookupOrCreate`.
enum class Policy {
  MustNot, ///< Must not create: the variable is required to exist already.
  May,     ///< Either finding or creating the variable is acceptable.
  Must     ///< Must create: the variable is required not to exist yet.
};
316 
317 //===----------------------------------------------------------------------===//
318 class VarEnv final {
319  /// Map from `VarKind` to the next free `Var::Num`; used by `bindVar`.
320  VarKindArray<Var::Num> nextNum;
321  /// Map from `VarInfo::ID` to shared storage for the actual `VarInfo` objects.
323  /// Map from variable names to their `VarInfo::ID`.
324  llvm::StringMap<VarInfo::ID> ids;
325 
326  VarInfo::ID nextID() const { return static_cast<VarInfo::ID>(vars.size()); }
327 
328 public:
329  VarEnv() : nextNum(0) {}
330 
331  /// Gets the underlying storage for the `VarInfo` identified by
332  /// the `VarInfo::ID`.
333  ///
334  /// NOTE: The returned reference can become dangling if the `VarEnv`
335  /// object is mutated during the lifetime of the pointer. Therefore,
336  /// client code should not store the reference nor otherwise allow it
337  /// to live too long.
338  VarInfo const &access(VarInfo::ID id) const {
339  // `SmallVector::operator[]` already asserts the index is in-bounds.
340  return vars[llvm::to_underlying(id)];
341  }
342  VarInfo const *access(std::optional<VarInfo::ID> oid) const {
343  return oid ? &access(*oid) : nullptr;
344  }
345 
346 private:
347  VarInfo &access(VarInfo::ID id) {
348  return const_cast<VarInfo &>(std::as_const(*this).access(id));
349  }
350  VarInfo *access(std::optional<VarInfo::ID> oid) {
351  return const_cast<VarInfo *>(std::as_const(*this).access(oid));
352  }
353 
354 public:
355  /// Looks up the variable with the given name.
356  std::optional<VarInfo::ID> lookup(StringRef name) const;
357 
358  /// Creates a new currently-unbound variable. When a variable
359  /// of that name already exists: if `verifyUsage` is true, then will assert
360  /// that the variable has the same kind and a consistent location; otherwise,
361  /// when `verifyUsage` is false, this is a noop. Returns the identifier
362  /// for the variable with the given name, and a bool indicating whether
363  /// a new variable was created.
364  std::optional<std::pair<VarInfo::ID, bool>>
365  create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage = false);
366 
367  /// Looks up or creates a variable according to the given
368  /// `Policy`. Returns nullopt in one of two circumstances:
369  /// (1) the policy says we `Must` create, yet the variable already exists;
370  /// (2) the policy says we `MustNot` create, yet no such variable exists.
371  /// Otherwise, if the variable already exists then it is validated against
372  /// the given kind and location to ensure consistency.
373  std::optional<std::pair<VarInfo::ID, bool>>
374  lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc,
375  VarKind vk);
376 
377  /// Binds the given variable to the next free `Var::Num` for its `VarKind`.
378  Var bindVar(VarInfo::ID id);
379 
380  /// Creates a new variable of the given kind and immediately binds it.
381  /// This should only be used whenever the variable is known to be unused
382  /// and therefore does not have a name.
383  Var bindUnusedVar(VarKind vk);
384 
386 
387  /// Returns the current ranks of bound variables. This method should
388  /// only be used after the environment is "finished", since binding new
389  /// variables will (semantically) invalidate any previously returned `Ranks`.
390  Ranks getRanks() const { return Ranks(nextNum); }
391 
392  /// Gets the `Var` identified by the `VarInfo::ID`, raising an assertion
393  /// failure if the variable is not bound.
394  Var getVar(VarInfo::ID id) const { return access(id).getVar(); }
395 };
396 
397 //===----------------------------------------------------------------------===//
398 
399 } // namespace ir_detail
400 } // namespace sparse_tensor
401 } // namespace mlir
402 
403 #endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:237
A symbolic identifier appearing in an affine expression.
Definition: AffineExpr.h:245
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
This base class exposes generic asm printer hooks, usable across the various derived printers.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
static constexpr bool classof(Var const *var)
Definition: Var.h:157
static constexpr VarKind Kind
Definition: Var.h:156
constexpr DimVar(Num dim)
Definition: Var.h:160
DimVar(AffineDimExpr dimExpr)
Definition: Var.h:161
constexpr LvlVar(Num lvl)
Definition: Var.h:172
static constexpr VarKind Kind
Definition: Var.h:168
static constexpr bool classof(Var const *var)
Definition: Var.h:169
LvlVar(AffineDimExpr lvlExpr)
Definition: Var.h:173
constexpr unsigned getRank(VarKind vk) const
Definition: Var.h:228
constexpr unsigned getLvlRank() const
Definition: Var.h:231
bool operator==(Ranks const &other) const
Definition: Var.cpp:50
Ranks(VarKindArray< unsigned > const &ranks)
Definition: Var.h:221
bool operator!=(Ranks const &other) const
Definition: Var.h:226
constexpr unsigned getDimRank() const
Definition: Var.h:230
constexpr unsigned getSymRank() const
Definition: Var.h:229
constexpr Ranks(unsigned symRank, unsigned dimRank, unsigned lvlRank)
Definition: Var.h:215
constexpr bool isValid(Var var) const
Definition: Var.h:233
constexpr SymVar(Num sym)
Definition: Var.h:148
SymVar(AffineSymbolExpr symExpr)
Definition: Var.h:149
static constexpr VarKind Kind
Definition: Var.h:144
static constexpr bool classof(Var const *var)
Definition: Var.h:145
Var bindUnusedVar(VarKind vk)
Creates a new variable of the given kind and immediately binds it.
Definition: Var.cpp:221
VarInfo const & access(VarInfo::ID id) const
Gets the underlying storage for the VarInfo identified by the VarInfo::ID.
Definition: Var.h:338
Ranks getRanks() const
Returns the current ranks of bound variables.
Definition: Var.h:390
InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const
Definition: Var.cpp:229
std::optional< std::pair< VarInfo::ID, bool > > lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc, VarKind vk)
Looks up or creates a variable according to the given Policy.
Definition: Var.cpp:197
Var getVar(VarInfo::ID id) const
Gets the Var identified by the VarInfo::ID, raising an assertion failure if the variable is not bound...
Definition: Var.h:394
std::optional< VarInfo::ID > lookup(StringRef name) const
Looks up the variable with the given name.
Definition: Var.cpp:170
VarInfo const * access(std::optional< VarInfo::ID > oid) const
Definition: Var.h:342
std::optional< std::pair< VarInfo::ID, bool > > create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage=false)
Creates a new currently-unbound variable.
Definition: Var.cpp:181
Var bindVar(VarInfo::ID id)
Binds the given variable to the next free Var::Num for its VarKind.
Definition: Var.cpp:222
A record of metadata for/about a variable, used by VarEnv.
Definition: Var.h:275
constexpr VarKind getKind() const
Definition: Var.h:304
constexpr ID getID() const
Definition: Var.h:303
constexpr llvm::SMLoc getLoc() const
Definition: Var.h:299
constexpr Var getVar() const
Definition: Var.h:308
constexpr StringRef getName() const
Definition: Var.h:298
ID
Newtype for unique identifiers of VarInfo records, to ensure they aren't confused with Var::Num.
Definition: Var.h:279
constexpr bool hasNum() const
Definition: Var.h:306
constexpr std::optional< Var::Num > getNum() const
Definition: Var.h:305
Location getLocation(AsmParser &parser) const
Definition: Var.h:300
constexpr VarInfo(ID id, StringRef name, llvm::SMLoc loc, VarKind vk, std::optional< Var::Num > n={})
Definition: Var.h:289
Efficient representation of a set of Var.
Definition: Var.h:242
bool contains(Var var) const
For the contains method: if variables occurring in the method parameter are OOB for the VarSet,...
Definition: Var.cpp:77
unsigned getRank(VarKind vk) const
Definition: Var.h:248
VarSet(Ranks const &ranks)
Definition: Var.cpp:71
void add(Var var)
For the add methods: OOB parameters cause undefined behavior.
Definition: Var.cpp:87
The underlying implementation of Var.
Definition: Var.h:90
constexpr Num getNum() const
Definition: Var.h:103
constexpr bool operator!=(Impl other) const
Definition: Var.h:101
constexpr VarKind getKind() const
Definition: Var.h:102
constexpr Impl(VarKind vk, Num n)
Definition: Var.h:94
constexpr bool operator==(Impl other) const
Definition: Var.h:100
A concrete variable, to be used in our variant of AffineExpr.
Definition: Var.h:61
constexpr Num getNum() const
Definition: Var.h:125
constexpr Var(Impl impl)
Protected ctor for the RTTI methods to use.
Definition: Var.h:112
constexpr std::optional< U > dyn_cast() const
Definition: Var.h:195
Var(VarKind vk, AffineDimExpr var)
Definition: Var.h:117
constexpr VarKind getKind() const
Definition: Var.h:124
void print(llvm::raw_ostream &os) const
Definition: Var.cpp:37
constexpr bool operator!=(Var other) const
Definition: Var.h:122
constexpr U cast() const
Definition: Var.h:188
constexpr Var(VarKind vk, Num n)
Definition: Var.h:115
std::string str() const
Definition: Var.cpp:28
static constexpr bool isWF_Num(Num n)
Checks whether the number would be accepted by Var(VarKind,Var::Num).
Definition: Var.h:83
constexpr bool isa() const
Definition: Var.h:178
Var(AffineSymbolExpr sym)
Definition: Var.h:116
constexpr bool operator==(Var other) const
Definition: Var.h:121
unsigned Num
Typedef for the type of variable numbers.
Definition: Var.h:64
Include the generated interface declarations.
Definition: CallGraph.h:229
constexpr bool isWF(VarKind vk)
Definition: Var.h:34
VarKind
The three kinds of variables that Var can be.
Definition: Var.h:32
llvm::EnumeratedArray< T, VarKind, VarKind::Level > VarKindArray
The type of arrays indexed by VarKind.
Definition: Var.h:55
constexpr char toChar(VarKind vk)
Gets the ASCII character used as the prefix when printing Var.
Definition: Var.h:40
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
Definition: SparseTensor.h:35
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:38
Include the generated interface declarations.