MLIR  19.0.0git
DimLvlMap.h
Go to the documentation of this file.
1 //===- DimLvlMap.h ----------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
10 #define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
11 
12 #include "Var.h"
13 
15 #include "llvm/ADT/STLForwardCompat.h"
16 
17 namespace mlir {
18 namespace sparse_tensor {
19 namespace ir_detail {
20 
21 //===----------------------------------------------------------------------===//
22 enum class ExprKind : bool { Dimension = false, Level = true };
23 
25  using VK = std::underlying_type_t<VarKind>;
26  return VarKind{2 * static_cast<VK>(!llvm::to_underlying(ek))};
27 }
30 
31 //===----------------------------------------------------------------------===//
32 class DimLvlExpr {
33 private:
34  ExprKind kind;
35  AffineExpr expr;
36 
37 public:
38  constexpr DimLvlExpr(ExprKind ek, AffineExpr expr) : kind(ek), expr(expr) {}
39 
40  //
41  // Boolean operators.
42  //
43  constexpr bool operator==(DimLvlExpr other) const {
44  return kind == other.kind && expr == other.expr;
45  }
46  constexpr bool operator!=(DimLvlExpr other) const {
47  return !(*this == other);
48  }
49  explicit operator bool() const { return static_cast<bool>(expr); }
50 
51  //
52  // RTTI support (for the `DimLvlExpr` class itself).
53  //
54  template <typename U>
55  constexpr bool isa() const;
56  template <typename U>
57  constexpr U cast() const;
58  template <typename U>
59  constexpr U dyn_cast() const;
60 
61  //
62  // Simple getters.
63  //
64  constexpr ExprKind getExprKind() const { return kind; }
65  constexpr VarKind getAllowedVarKind() const {
66  return getVarKindAllowedInExpr(kind);
67  }
68  constexpr AffineExpr getAffineExpr() const { return expr; }
70  assert(expr);
71  return expr.getKind();
72  }
74  return expr ? expr.getContext() : nullptr;
75  }
76 
77  //
78  // Getters for handling `AffineExpr` subclasses.
79  //
80  SymVar castSymVar() const;
81  std::optional<SymVar> dyn_castSymVar() const;
82  Var castDimLvlVar() const;
83  std::optional<Var> dyn_castDimLvlVar() const;
84  std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr> unpackBinop() const;
85 
86  /// Checks whether the variables bound/used by this spec are valid
87  /// with respect to the given ranks.
88  [[nodiscard]] bool isValid(Ranks const &ranks) const;
89 
90 protected:
91  // Variant of `mlir::AsmPrinter::Impl::BindingStrength`
92  enum class BindingStrength : bool { Weak = false, Strong = true };
93 };
94 static_assert(IsZeroCostAbstraction<DimLvlExpr>);
95 
96 class DimExpr final : public DimLvlExpr {
97  friend class DimLvlExpr;
98  constexpr explicit DimExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
99 
100 public:
101  static constexpr ExprKind Kind = ExprKind::Dimension;
102  static constexpr bool classof(DimLvlExpr const *expr) {
103  return expr->getExprKind() == Kind;
104  }
105  constexpr explicit DimExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
106 
107  LvlVar castLvlVar() const { return castDimLvlVar().cast<LvlVar>(); }
108  std::optional<LvlVar> dyn_castLvlVar() const {
109  const auto var = dyn_castDimLvlVar();
110  return var ? std::make_optional(var->cast<LvlVar>()) : std::nullopt;
111  }
112 };
113 static_assert(IsZeroCostAbstraction<DimExpr>);
114 
115 class LvlExpr final : public DimLvlExpr {
116  friend class DimLvlExpr;
117  constexpr explicit LvlExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
118 
119 public:
120  static constexpr ExprKind Kind = ExprKind::Level;
121  static constexpr bool classof(DimLvlExpr const *expr) {
122  return expr->getExprKind() == Kind;
123  }
124  constexpr explicit LvlExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
125 
126  DimVar castDimVar() const { return castDimLvlVar().cast<DimVar>(); }
127  std::optional<DimVar> dyn_castDimVar() const {
128  const auto var = dyn_castDimLvlVar();
129  return var ? std::make_optional(var->cast<DimVar>()) : std::nullopt;
130  }
131 };
132 static_assert(IsZeroCostAbstraction<LvlExpr>);
133 
134 template <typename U>
135 constexpr bool DimLvlExpr::isa() const {
136  if constexpr (std::is_same_v<U, DimExpr>)
137  return getExprKind() == ExprKind::Dimension;
138  if constexpr (std::is_same_v<U, LvlExpr>)
139  return getExprKind() == ExprKind::Level;
140 }
141 
142 template <typename U>
143 constexpr U DimLvlExpr::cast() const {
144  assert(isa<U>());
145  return U(*this);
146 }
147 
148 template <typename U>
149 constexpr U DimLvlExpr::dyn_cast() const {
150  return isa<U>() ? U(*this) : U();
151 }
152 
153 //===----------------------------------------------------------------------===//
154 /// The full `dimVar = dimExpr : dimSlice` specification for a given dimension.
155 class DimSpec final {
156  /// The dimension-variable bound by this specification.
157  DimVar var;
158  /// The dimension-expression. The `DimSpec` ctor treats this field
159  /// as optional; whereas the `DimLvlMap` ctor will fill in (or verify)
160  /// the expression via function-inversion inference.
161  DimExpr expr;
162  /// Can the `expr` be elided when printing? The `DimSpec` ctor assumes
163  /// not (though if `expr` is null it will elide printing that); whereas
164  /// the `DimLvlMap` ctor will reset it as appropriate.
165  bool elideExpr = false;
166  /// The dimension-slice; optional, default is null.
167  SparseTensorDimSliceAttr slice;
168 
169 public:
170  DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice);
171 
172  MLIRContext *tryGetContext() const { return expr.tryGetContext(); }
173 
174  constexpr DimVar getBoundVar() const { return var; }
175  bool hasExpr() const { return static_cast<bool>(expr); }
176  constexpr DimExpr getExpr() const { return expr; }
177  void setExpr(DimExpr newExpr) {
178  assert(!hasExpr());
179  expr = newExpr;
180  }
181  constexpr bool canElideExpr() const { return elideExpr; }
182  void setElideExpr(bool b) { elideExpr = b; }
183  constexpr SparseTensorDimSliceAttr getSlice() const { return slice; }
184 
185  /// Checks whether the variables bound/used by this spec are valid with
186  /// respect to the given ranks. Note that null `DimExpr` is considered
187  /// to be vacuously valid, and therefore calling `setExpr` invalidates
188  /// the result of this predicate.
189  [[nodiscard]] bool isValid(Ranks const &ranks) const;
190 };
191 
192 static_assert(IsZeroCostAbstraction<DimSpec>);
193 
194 //===----------------------------------------------------------------------===//
195 /// The full `lvlVar = lvlExpr : lvlType` specification for a given level.
196 class LvlSpec final {
197  /// The level-variable bound by this specification.
198  LvlVar var;
199  /// Can the `var` be elided when printing? The `LvlSpec` ctor assumes not;
200  /// whereas the `DimLvlMap` ctor will reset this as appropriate.
201  bool elideVar = false;
202  /// The level-expression.
203  LvlExpr expr;
204  /// The level-type (== level-format + lvl-properties).
205  LevelType type;
206 
207 public:
208  LvlSpec(LvlVar var, LvlExpr expr, LevelType type);
209 
211  MLIRContext *ctx = expr.tryGetContext();
212  assert(ctx);
213  return ctx;
214  }
215 
216  constexpr LvlVar getBoundVar() const { return var; }
217  constexpr bool canElideVar() const { return elideVar; }
218  void setElideVar(bool b) { elideVar = b; }
219  constexpr LvlExpr getExpr() const { return expr; }
220  constexpr LevelType getType() const { return type; }
221 
222  /// Checks whether the variables bound/used by this spec are valid
223  /// with respect to the given ranks.
224  [[nodiscard]] bool isValid(Ranks const &ranks) const;
225 };
226 
227 static_assert(IsZeroCostAbstraction<LvlSpec>);
228 
229 //===----------------------------------------------------------------------===//
230 class DimLvlMap final {
231 public:
232  DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
233  ArrayRef<LvlSpec> lvlSpecs);
234 
235  unsigned getSymRank() const { return symRank; }
236  unsigned getDimRank() const { return dimSpecs.size(); }
237  unsigned getLvlRank() const { return lvlSpecs.size(); }
238  unsigned getRank(VarKind vk) const { return getRanks().getRank(vk); }
239  Ranks getRanks() const { return {getSymRank(), getDimRank(), getLvlRank()}; }
240 
241  ArrayRef<DimSpec> getDims() const { return dimSpecs; }
242  const DimSpec &getDim(Dimension dim) const { return dimSpecs[dim]; }
243  SparseTensorDimSliceAttr getDimSlice(Dimension dim) const {
244  return getDim(dim).getSlice();
245  }
246 
247  ArrayRef<LvlSpec> getLvls() const { return lvlSpecs; }
248  const LvlSpec &getLvl(Level lvl) const { return lvlSpecs[lvl]; }
249  LevelType getLvlType(Level lvl) const { return getLvl(lvl).getType(); }
250 
251  AffineMap getDimToLvlMap(MLIRContext *context) const;
252  AffineMap getLvlToDimMap(MLIRContext *context) const;
253 
254 private:
255  /// Checks for integrity of variable-binding structure.
256  /// This is already called by the ctor.
257  [[nodiscard]] bool isWF() const;
258 
259  /// Helper function to call `DimSpec::setExpr` while asserting that
260  /// the invariant established by `DimLvlMap:isWF` is maintained.
261  /// This is used by the ctor.
262  void setDimExpr(Dimension dim, DimExpr expr) {
263  assert(expr && getRanks().isValid(expr));
264  dimSpecs[dim].setExpr(expr);
265  }
266 
267  // All these fields are const-after-ctor.
268  unsigned symRank;
269  SmallVector<DimSpec> dimSpecs;
270  SmallVector<LvlSpec> lvlSpecs;
271  bool mustPrintLvlVars;
272 };
273 
274 //===----------------------------------------------------------------------===//
275 
276 } // namespace ir_detail
277 } // namespace sparse_tensor
278 } // namespace mlir
279 
280 #endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
Base type for affine expressions.
Definition: AffineExpr.h:69
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:27
MLIRContext * getContext() const
Definition: AffineExpr.cpp:25
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
static constexpr ExprKind Kind
Definition: DimLvlMap.h:101
static constexpr bool classof(DimLvlExpr const *expr)
Definition: DimLvlMap.h:102
std::optional< LvlVar > dyn_castLvlVar() const
Definition: DimLvlMap.h:108
constexpr DimExpr(AffineExpr expr)
Definition: DimLvlMap.h:105
bool isValid(Ranks const &ranks) const
Checks whether the variables bound/used by this spec are valid with respect to the given ranks.
constexpr bool operator!=(DimLvlExpr other) const
Definition: DimLvlMap.h:46
std::optional< SymVar > dyn_castSymVar() const
Definition: DimLvlMap.cpp:23
AffineExprKind getAffineKind() const
Definition: DimLvlMap.h:69
constexpr AffineExpr getAffineExpr() const
Definition: DimLvlMap.h:68
constexpr ExprKind getExprKind() const
Definition: DimLvlMap.h:64
constexpr bool operator==(DimLvlExpr other) const
Definition: DimLvlMap.h:43
std::tuple< DimLvlExpr, AffineExprKind, DimLvlExpr > unpackBinop() const
Definition: DimLvlMap.cpp:40
constexpr DimLvlExpr(ExprKind ek, AffineExpr expr)
Definition: DimLvlMap.h:38
constexpr VarKind getAllowedVarKind() const
Definition: DimLvlMap.h:65
std::optional< Var > dyn_castDimLvlVar() const
Definition: DimLvlMap.cpp:33
DimLvlMap(unsigned symRank, ArrayRef< DimSpec > dimSpecs, ArrayRef< LvlSpec > lvlSpecs)
Definition: DimLvlMap.cpp:80
SparseTensorDimSliceAttr getDimSlice(Dimension dim) const
Definition: DimLvlMap.h:243
AffineMap getDimToLvlMap(MLIRContext *context) const
Definition: DimLvlMap.cpp:118
ArrayRef< LvlSpec > getLvls() const
Definition: DimLvlMap.h:247
AffineMap getLvlToDimMap(MLIRContext *context) const
Definition: DimLvlMap.cpp:127
unsigned getRank(VarKind vk) const
Definition: DimLvlMap.h:238
const DimSpec & getDim(Dimension dim) const
Definition: DimLvlMap.h:242
ArrayRef< DimSpec > getDims() const
Definition: DimLvlMap.h:241
LevelType getLvlType(Level lvl) const
Definition: DimLvlMap.h:249
const LvlSpec & getLvl(Level lvl) const
Definition: DimLvlMap.h:248
The full dimVar = dimExpr : dimSlice specification for a given dimension.
Definition: DimLvlMap.h:155
constexpr DimExpr getExpr() const
Definition: DimLvlMap.h:176
bool isValid(Ranks const &ranks) const
Checks whether the variables bound/used by this spec are valid with respect to the given ranks.
Definition: DimLvlMap.cpp:55
constexpr bool canElideExpr() const
Definition: DimLvlMap.h:181
MLIRContext * tryGetContext() const
Definition: DimLvlMap.h:172
constexpr SparseTensorDimSliceAttr getSlice() const
Definition: DimLvlMap.h:183
DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice)
Definition: DimLvlMap.cpp:52
constexpr DimVar getBoundVar() const
Definition: DimLvlMap.h:174
constexpr LvlExpr(AffineExpr expr)
Definition: DimLvlMap.h:124
static constexpr bool classof(DimLvlExpr const *expr)
Definition: DimLvlMap.h:121
static constexpr ExprKind Kind
Definition: DimLvlMap.h:120
std::optional< DimVar > dyn_castDimVar() const
Definition: DimLvlMap.h:127
The full lvlVar = lvlExpr : lvlType specification for a given level.
Definition: DimLvlMap.h:196
bool isValid(Ranks const &ranks) const
Checks whether the variables bound/used by this spec are valid with respect to the given ranks.
Definition: DimLvlMap.cpp:71
LvlSpec(LvlVar var, LvlExpr expr, LevelType type)
Definition: DimLvlMap.cpp:65
constexpr LvlExpr getExpr() const
Definition: DimLvlMap.h:219
constexpr bool canElideVar() const
Definition: DimLvlMap.h:217
MLIRContext * getContext() const
Definition: DimLvlMap.h:210
constexpr LvlVar getBoundVar() const
Definition: DimLvlMap.h:216
constexpr LevelType getType() const
Definition: DimLvlMap.h:220
constexpr unsigned getRank(VarKind vk) const
Definition: Var.h:228
A concrete variable, to be used in our variant of AffineExpr.
Definition: Var.h:61
constexpr U cast() const
Definition: Var.h:188
VarKind
The three kinds of variables that Var can be.
Definition: Var.h:32
constexpr VarKind getVarKindAllowedInExpr(ExprKind ek)
Definition: DimLvlMap.h:24
uint64_t Dimension
The type of dimension identifiers and dimension-ranks.
Definition: SparseTensor.h:35
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:38
Include the generated interface declarations.
AffineExprKind
Definition: AffineExpr.h:41
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238