MLIR  19.0.0git
DimLvlMap.cpp
Go to the documentation of this file.
1 //===- DimLvlMap.cpp ------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "DimLvlMap.h"
10 
11 using namespace mlir;
12 using namespace mlir::sparse_tensor;
13 using namespace mlir::sparse_tensor::ir_detail;
14 
15 //===----------------------------------------------------------------------===//
16 // `DimLvlExpr` implementation.
17 //===----------------------------------------------------------------------===//
18 
20  return SymVar(llvm::cast<AffineSymbolExpr>(expr));
21 }
22 
23 std::optional<SymVar> DimLvlExpr::dyn_castSymVar() const {
24  if (const auto s = dyn_cast_or_null<AffineSymbolExpr>(expr))
25  return SymVar(s);
26  return std::nullopt;
27 }
28 
30  return Var(getAllowedVarKind(), llvm::cast<AffineDimExpr>(expr));
31 }
32 
33 std::optional<Var> DimLvlExpr::dyn_castDimLvlVar() const {
34  if (const auto x = dyn_cast_or_null<AffineDimExpr>(expr))
35  return Var(getAllowedVarKind(), x);
36  return std::nullopt;
37 }
38 
39 std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr>
41  const auto ak = getAffineKind();
42  const auto binop = llvm::dyn_cast<AffineBinaryOpExpr>(expr);
43  const DimLvlExpr lhs(kind, binop ? binop.getLHS() : nullptr);
44  const DimLvlExpr rhs(kind, binop ? binop.getRHS() : nullptr);
45  return {lhs, ak, rhs};
46 }
47 
48 //===----------------------------------------------------------------------===//
49 // `DimSpec` implementation.
50 //===----------------------------------------------------------------------===//
51 
52 DimSpec::DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice)
53  : var(var), expr(expr), slice(slice) {}
54 
55 bool DimSpec::isValid(Ranks const &ranks) const {
56  // Nothing in `slice` needs additional validation.
57  // We explicitly consider null-expr to be vacuously valid.
58  return ranks.isValid(var) && (!expr || ranks.isValid(expr));
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // `LvlSpec` implementation.
63 //===----------------------------------------------------------------------===//
64 
66  : var(var), expr(expr), type(type) {
67  assert(expr);
68  assert(isValidLT(type) && !isUndefLT(type));
69 }
70 
71 bool LvlSpec::isValid(Ranks const &ranks) const {
72  // Nothing in `type` needs additional validation.
73  return ranks.isValid(var) && ranks.isValid(expr);
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // `DimLvlMap` implementation.
78 //===----------------------------------------------------------------------===//
79 
80 DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
81  ArrayRef<LvlSpec> lvlSpecs)
82  : symRank(symRank), dimSpecs(dimSpecs), lvlSpecs(lvlSpecs),
83  mustPrintLvlVars(false) {
84  // First, check integrity of the variable-binding structure.
85  // NOTE: This establishes the invariant that calls to `VarSet::add`
86  // below cannot cause OOB errors.
87  assert(isWF());
88 
89  VarSet usedVars(getRanks());
90  for (const auto &dimSpec : dimSpecs)
91  if (!dimSpec.canElideExpr())
92  usedVars.add(dimSpec.getExpr());
93  for (auto &lvlSpec : this->lvlSpecs) {
94  // Is this LvlVar used in any overt expression?
95  const bool isUsed = usedVars.contains(lvlSpec.getBoundVar());
96  // This LvlVar can be elided iff it isn't overtly used.
97  lvlSpec.setElideVar(!isUsed);
98  // If any LvlVar cannot be elided, then must forward-declare all LvlVars.
99  mustPrintLvlVars = mustPrintLvlVars || isUsed;
100  }
101 }
102 
103 bool DimLvlMap::isWF() const {
104  const auto ranks = getRanks();
105  unsigned dimNum = 0;
106  for (const auto &dimSpec : dimSpecs)
107  if (dimSpec.getBoundVar().getNum() != dimNum++ || !dimSpec.isValid(ranks))
108  return false;
109  assert(dimNum == ranks.getDimRank());
110  unsigned lvlNum = 0;
111  for (const auto &lvlSpec : lvlSpecs)
112  if (lvlSpec.getBoundVar().getNum() != lvlNum++ || !lvlSpec.isValid(ranks))
113  return false;
114  assert(lvlNum == ranks.getLvlRank());
115  return true;
116 }
117 
119  SmallVector<AffineExpr> lvlAffines;
120  lvlAffines.reserve(getLvlRank());
121  for (const auto &lvlSpec : lvlSpecs)
122  lvlAffines.push_back(lvlSpec.getExpr().getAffineExpr());
123  auto map = AffineMap::get(getDimRank(), getSymRank(), lvlAffines, context);
124  return map;
125 }
126 
128  SmallVector<AffineExpr> dimAffines;
129  dimAffines.reserve(getDimRank());
130  for (const auto &dimSpec : dimSpecs) {
131  auto expr = dimSpec.getExpr().getAffineExpr();
132  if (expr) {
133  dimAffines.push_back(expr);
134  }
135  }
136  auto map = AffineMap::get(getLvlRank(), getSymRank(), dimAffines, context);
137  // If no lvlToDim map was passed in, returns a null AffineMap and infers it
138  // in SparseTensorEncodingAttr::parse.
139  if (dimAffines.empty())
140  return AffineMap();
141  return map;
142 }
143 
144 //===----------------------------------------------------------------------===//
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
std::optional< SymVar > dyn_castSymVar() const
Definition: DimLvlMap.cpp:23
AffineExprKind getAffineKind() const
Definition: DimLvlMap.h:69
std::tuple< DimLvlExpr, AffineExprKind, DimLvlExpr > unpackBinop() const
Definition: DimLvlMap.cpp:40
constexpr VarKind getAllowedVarKind() const
Definition: DimLvlMap.h:65
std::optional< Var > dyn_castDimLvlVar() const
Definition: DimLvlMap.cpp:33
DimLvlMap(unsigned symRank, ArrayRef< DimSpec > dimSpecs, ArrayRef< LvlSpec > lvlSpecs)
Definition: DimLvlMap.cpp:80
AffineMap getDimToLvlMap(MLIRContext *context) const
Definition: DimLvlMap.cpp:118
AffineMap getLvlToDimMap(MLIRContext *context) const
Definition: DimLvlMap.cpp:127
bool isValid(Ranks const &ranks) const
Checks whether the variables bound/used by this spec are valid with respect to the given ranks.
Definition: DimLvlMap.cpp:55
DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice)
Definition: DimLvlMap.cpp:52
bool isValid(Ranks const &ranks) const
Checks whether the variables bound/used by this spec are valid with respect to the given ranks.
Definition: DimLvlMap.cpp:71
LvlSpec(LvlVar var, LvlExpr expr, LevelType type)
Definition: DimLvlMap.cpp:65
constexpr bool isValid(Var var) const
Definition: Var.h:233
Efficient representation of a set of Var.
Definition: Var.h:242
bool contains(Var var) const
For the contains method: if variables occurring in the method parameter are OOB for the VarSet,...
Definition: Var.cpp:77
void add(Var var)
For the add methods: OOB parameters cause undefined behavior.
Definition: Var.cpp:87
A concrete variable, to be used in our variant of AffineExpr.
Definition: Var.h:61
bool isUndefLT(LevelType lt)
Definition: Enums.h:408
bool isValidLT(LevelType lt)
Definition: Enums.h:429
Include the generated interface declarations.
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238